mirror of
https://github.com/element-hq/synapse.git
synced 2025-12-11 01:40:27 +00:00
Compare commits
23 Commits
erikj/drop
...
devon/lock
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
126f0c3587 | ||
|
|
e519ee230b | ||
|
|
ac1bf682ff | ||
|
|
a0b70473fc | ||
|
|
95a85b1129 | ||
|
|
3d8535b1de | ||
|
|
628351b98d | ||
|
|
8f27b3af07 | ||
|
|
579f4ac1cd | ||
|
|
c53999dab8 | ||
|
|
b41a9ebb38 | ||
|
|
6ec5e13ec9 | ||
|
|
148e93576e | ||
|
|
56ed412839 | ||
|
|
9c5d08fff8 | ||
|
|
90a6bd01c2 | ||
|
|
aa07a01452 | ||
|
|
8364c01a2b | ||
|
|
e27808f306 | ||
|
|
048c1ac7f6 | ||
|
|
ca290d325c | ||
|
|
0a31cf18cd | ||
|
|
48db0c2d6c |
2
.github/workflows/docs-pr-netlify.yaml
vendored
2
.github/workflows/docs-pr-netlify.yaml
vendored
@@ -14,7 +14,7 @@ jobs:
|
||||
# There's a 'download artifact' action, but it hasn't been updated for the workflow_run action
|
||||
# (https://github.com/actions/download-artifact/issues/60) so instead we get this mess:
|
||||
- name: 📥 Download artifact
|
||||
uses: dawidd6/action-download-artifact@80620a5d27ce0ae443b965134db88467fc607b43 # v7
|
||||
uses: dawidd6/action-download-artifact@20319c5641d495c8a52e688b7dc5fada6c3a9fbc # v8
|
||||
with:
|
||||
workflow: docs-pr.yaml
|
||||
run_id: ${{ github.event.workflow_run.id }}
|
||||
|
||||
44
CHANGES.md
44
CHANGES.md
@@ -1,3 +1,47 @@
|
||||
# Synapse 1.123.0 (2025-01-28)
|
||||
|
||||
No significant changes since 1.123.0rc1.
|
||||
|
||||
|
||||
|
||||
|
||||
# Synapse 1.123.0rc1 (2025-01-21)
|
||||
|
||||
### Features
|
||||
|
||||
- Implement [MSC4133](https://github.com/matrix-org/matrix-spec-proposals/pull/4133) for custom profile fields. Contributed by @clokep. ([\#17488](https://github.com/element-hq/synapse/issues/17488))
|
||||
- Add a query parameter `type` to the [Room State Admin API](https://element-hq.github.io/synapse/develop/admin_api/rooms.html#room-state-api) that filters the state event. ([\#18035](https://github.com/element-hq/synapse/issues/18035))
|
||||
- Support the new `/auth_metadata` endpoint defined in [MSC2965](https://github.com/matrix-org/matrix-spec-proposals/pull/2965). ([\#18093](https://github.com/element-hq/synapse/issues/18093))
|
||||
|
||||
### Bugfixes
|
||||
|
||||
- Fix membership caches not updating in state reset scenarios. ([\#17732](https://github.com/element-hq/synapse/issues/17732))
|
||||
- Fix rare race where on upgrade to v1.122.0 a long running database upgrade could lock out new events from being received or sent. ([\#18091](https://github.com/element-hq/synapse/issues/18091))
|
||||
|
||||
### Improved Documentation
|
||||
|
||||
- Document `tls` option for a worker instance in `instance_map`. ([\#18064](https://github.com/element-hq/synapse/issues/18064))
|
||||
|
||||
### Deprecations and Removals
|
||||
|
||||
- Remove the unstable [MSC4151](https://github.com/matrix-org/matrix-spec-proposals/pull/4151) implementation. The stable support remains, per [Matrix 1.13](https://spec.matrix.org/v1.13/client-server-api/#post_matrixclientv3roomsroomidreport). ([\#18052](https://github.com/element-hq/synapse/issues/18052))
|
||||
|
||||
### Internal Changes
|
||||
|
||||
- Increase invite rate limits (`rc_invites.per_issuer`) for Complement. ([\#18072](https://github.com/element-hq/synapse/issues/18072))
|
||||
|
||||
|
||||
|
||||
### Updates to locked dependencies
|
||||
|
||||
* Bump jinja2 from 3.1.4 to 3.1.5. ([\#18067](https://github.com/element-hq/synapse/issues/18067))
|
||||
* Bump mypy from 1.12.1 to 1.13.0. ([\#18083](https://github.com/element-hq/synapse/issues/18083))
|
||||
* Bump pillow from 11.0.0 to 11.1.0. ([\#18084](https://github.com/element-hq/synapse/issues/18084))
|
||||
* Bump pyo3 from 0.23.3 to 0.23.4. ([\#18079](https://github.com/element-hq/synapse/issues/18079))
|
||||
* Bump pyopenssl from 24.2.1 to 24.3.0. ([\#18062](https://github.com/element-hq/synapse/issues/18062))
|
||||
* Bump serde_json from 1.0.134 to 1.0.135. ([\#18081](https://github.com/element-hq/synapse/issues/18081))
|
||||
* Bump ulid from 1.1.3 to 1.1.4. ([\#18080](https://github.com/element-hq/synapse/issues/18080))
|
||||
|
||||
# Synapse 1.122.0 (2025-01-14)
|
||||
|
||||
Please note that this version of Synapse drops support for PostgreSQL 11 and 12. The minimum version of PostgreSQL supported is now version 13.
|
||||
|
||||
8
Cargo.lock
generated
8
Cargo.lock
generated
@@ -216,9 +216,9 @@ checksum = "ae743338b92ff9146ce83992f766a31066a91a8c84a45e0e9f21e7cf6de6d346"
|
||||
|
||||
[[package]]
|
||||
name = "log"
|
||||
version = "0.4.22"
|
||||
version = "0.4.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
|
||||
checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f"
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
@@ -449,9 +449,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.135"
|
||||
version = "1.0.137"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2b0d7ba2887406110130a978386c4e1befb98c674b4fba677954e4db976630d9"
|
||||
checksum = "930cfb6e6abf99298aaad7d29abbef7a9999a9a8806a40088f55f0dcec03146b"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"memchr",
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
Fix membership caches not updating in state reset scenarios.
|
||||
1
changelog.d/18000.bugfix
Normal file
1
changelog.d/18000.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Add rate limit `rc_presence.per_user`. This prevents load from excessive presence updates sent by clients via sync api. Also rate limit `/_matrix/client/v3/presence` as per the spec. Contributed by @rda0.
|
||||
@@ -1 +0,0 @@
|
||||
Add a unit test for the `type` parameter of the [Room State Admin API](https://element-hq.github.io/synapse/develop/admin_api/rooms.html#room-state-api).
|
||||
@@ -1 +0,0 @@
|
||||
Remove the unstable [MSC4151](https://github.com/matrix-org/matrix-spec-proposals/pull/4151) implementation. The stable support remains, per [Matrix 1.13](https://spec.matrix.org/v1.13/client-server-api/#post_matrixclientv3roomsroomidreport).
|
||||
@@ -1 +0,0 @@
|
||||
Increase invite rate limits (`rc_invites.per_issuer`) for Complement.
|
||||
1
changelog.d/18073.bugfix
Normal file
1
changelog.d/18073.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Deactivated users will no longer automatically accept an invite when `auto_accept_invites` is enabled.
|
||||
1
changelog.d/18075.bugfix
Normal file
1
changelog.d/18075.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix join being denied after being invited over federation. Also fixes other out-of-band membership transitions.
|
||||
2
changelog.d/18089.bugfix
Normal file
2
changelog.d/18089.bugfix
Normal file
@@ -0,0 +1,2 @@
|
||||
Updates contributed `docker-compose.yml` file to PostgreSQL v15, as v12 is no longer supported by Synapse.
|
||||
Contributed by @maxkratz.
|
||||
1
changelog.d/18109.misc
Normal file
1
changelog.d/18109.misc
Normal file
@@ -0,0 +1 @@
|
||||
Increase the length of the generated `nonce` parameter when perfoming OIDC logins to comply with the TI-Messenger spec.
|
||||
1
changelog.d/18112.bugfix
Normal file
1
changelog.d/18112.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Raise an error if someone is using an incorrect suffix in a config duration string.
|
||||
1
changelog.d/18119.bugfix
Normal file
1
changelog.d/18119.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix a bug where the [Delete Room Admin API](https://element-hq.github.io/synapse/latest/admin_api/rooms.html#version-2-new-version) would fail if the `block` parameter was set to `true` and a worker other than the main process was configured to handle background tasks.
|
||||
1
changelog.d/18124.misc
Normal file
1
changelog.d/18124.misc
Normal file
@@ -0,0 +1 @@
|
||||
Add log message when worker lock timeouts get large.
|
||||
@@ -51,7 +51,7 @@ services:
|
||||
- traefik.http.routers.https-synapse.tls.certResolver=le-ssl
|
||||
|
||||
db:
|
||||
image: docker.io/postgres:12-alpine
|
||||
image: docker.io/postgres:15-alpine
|
||||
# Change that password, of course!
|
||||
environment:
|
||||
- POSTGRES_USER=synapse
|
||||
|
||||
12
debian/changelog
vendored
12
debian/changelog
vendored
@@ -1,3 +1,15 @@
|
||||
matrix-synapse-py3 (1.123.0) stable; urgency=medium
|
||||
|
||||
* New Synapse release 1.123.0.
|
||||
|
||||
-- Synapse Packaging team <packages@matrix.org> Tue, 28 Jan 2025 08:37:34 -0700
|
||||
|
||||
matrix-synapse-py3 (1.123.0~rc1) stable; urgency=medium
|
||||
|
||||
* New Synapse release 1.123.0rc1.
|
||||
|
||||
-- Synapse Packaging team <packages@matrix.org> Tue, 21 Jan 2025 14:39:57 +0100
|
||||
|
||||
matrix-synapse-py3 (1.122.0) stable; urgency=medium
|
||||
|
||||
* New Synapse release 1.122.0.
|
||||
|
||||
@@ -89,6 +89,11 @@ rc_invites:
|
||||
per_second: 1000
|
||||
burst_count: 1000
|
||||
|
||||
rc_presence:
|
||||
per_user:
|
||||
per_second: 9999
|
||||
burst_count: 9999
|
||||
|
||||
federation_rr_transactions_per_room_per_second: 9999
|
||||
|
||||
allow_device_name_lookup_over_federation: true
|
||||
|
||||
@@ -1868,6 +1868,27 @@ rc_federation:
|
||||
concurrent: 5
|
||||
```
|
||||
---
|
||||
### `rc_presence`
|
||||
|
||||
This option sets ratelimiting for presence.
|
||||
|
||||
The `rc_presence.per_user` option sets rate limits on how often a specific
|
||||
users' presence updates are evaluated. Ratelimited presence updates sent via sync are
|
||||
ignored, and no error is returned to the client.
|
||||
This option also sets the rate limit for the
|
||||
[`PUT /_matrix/client/v3/presence/{userId}/status`](https://spec.matrix.org/latest/client-server-api/#put_matrixclientv3presenceuseridstatus)
|
||||
endpoint.
|
||||
|
||||
`per_user` defaults to `per_second: 0.1`, `burst_count: 1`.
|
||||
|
||||
Example configuration:
|
||||
```yaml
|
||||
rc_presence:
|
||||
per_user:
|
||||
per_second: 0.05
|
||||
burst_count: 0.5
|
||||
```
|
||||
---
|
||||
### `federation_rr_transactions_per_room_per_second`
|
||||
|
||||
Sets outgoing federation transaction frequency for sending read-receipts,
|
||||
@@ -4465,6 +4486,10 @@ instance_map:
|
||||
worker1:
|
||||
host: localhost
|
||||
port: 8034
|
||||
other:
|
||||
host: localhost
|
||||
port: 8035
|
||||
tls: true
|
||||
```
|
||||
Example configuration(#2, for UNIX sockets):
|
||||
```yaml
|
||||
|
||||
12
poetry.lock
generated
12
poetry.lock
generated
@@ -1960,13 +1960,13 @@ six = ">=1.5"
|
||||
|
||||
[[package]]
|
||||
name = "python-multipart"
|
||||
version = "0.0.18"
|
||||
version = "0.0.20"
|
||||
description = "A streaming multipart parser for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "python_multipart-0.0.18-py3-none-any.whl", hash = "sha256:efe91480f485f6a361427a541db4796f9e1591afc0fb8e7a4ba06bfbc6708996"},
|
||||
{file = "python_multipart-0.0.18.tar.gz", hash = "sha256:7a68db60c8bfb82e460637fa4750727b45af1d5e2ed215593f917f64694d34fe"},
|
||||
{file = "python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104"},
|
||||
{file = "python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2706,13 +2706,13 @@ twisted = "*"
|
||||
|
||||
[[package]]
|
||||
name = "types-bleach"
|
||||
version = "6.1.0.20240331"
|
||||
version = "6.2.0.20241123"
|
||||
description = "Typing stubs for bleach"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "types-bleach-6.1.0.20240331.tar.gz", hash = "sha256:2ee858a84fb06fc2225ff56ba2f7f6c88b65638659efae0d7bfd6b24a1b5a524"},
|
||||
{file = "types_bleach-6.1.0.20240331-py3-none-any.whl", hash = "sha256:399bc59bfd20a36a56595f13f805e56c8a08e5a5c07903e5cf6fafb5a5107dd4"},
|
||||
{file = "types_bleach-6.2.0.20241123-py3-none-any.whl", hash = "sha256:c6e58b3646665ca7c6b29890375390f4569e84f0cf5c171e0fe1ddb71a7be86a"},
|
||||
{file = "types_bleach-6.2.0.20241123.tar.gz", hash = "sha256:dac5fe9015173514da3ac810c1a935619a3ccbcc5d66c4cbf4707eac00539057"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
||||
@@ -97,7 +97,7 @@ module-name = "synapse.synapse_rust"
|
||||
|
||||
[tool.poetry]
|
||||
name = "matrix-synapse"
|
||||
version = "1.122.0"
|
||||
version = "1.123.0"
|
||||
description = "Homeserver for the Matrix decentralised comms protocol"
|
||||
authors = ["Matrix.org Team and Contributors <packages@matrix.org>"]
|
||||
license = "AGPL-3.0-or-later"
|
||||
|
||||
@@ -174,6 +174,12 @@ class MSC3861DelegatedAuth(BaseAuth):
|
||||
logger.warning("Failed to load metadata:", exc_info=True)
|
||||
return None
|
||||
|
||||
async def auth_metadata(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns the auth metadata dict
|
||||
"""
|
||||
return await self._issuer_metadata.get()
|
||||
|
||||
async def _introspection_endpoint(self) -> str:
|
||||
"""
|
||||
Returns the introspection endpoint of the issuer
|
||||
|
||||
@@ -132,6 +132,10 @@ class Codes(str, Enum):
|
||||
# connection.
|
||||
UNKNOWN_POS = "M_UNKNOWN_POS"
|
||||
|
||||
# Part of MSC4133
|
||||
PROFILE_TOO_LARGE = "M_PROFILE_TOO_LARGE"
|
||||
KEY_TOO_LARGE = "M_KEY_TOO_LARGE"
|
||||
|
||||
|
||||
class CodeMessageException(RuntimeError):
|
||||
"""An exception with integer code, a message string attributes and optional headers.
|
||||
|
||||
@@ -275,6 +275,7 @@ class Ratelimiter:
|
||||
update: bool = True,
|
||||
n_actions: int = 1,
|
||||
_time_now_s: Optional[float] = None,
|
||||
pause: Optional[float] = 0.5,
|
||||
) -> None:
|
||||
"""Checks if an action can be performed. If not, raises a LimitExceededError
|
||||
|
||||
@@ -298,6 +299,8 @@ class Ratelimiter:
|
||||
at all.
|
||||
_time_now_s: The current time. Optional, defaults to the current time according
|
||||
to self.clock. Only used by tests.
|
||||
pause: Time in seconds to pause when an action is being limited. Defaults to 0.5
|
||||
to stop clients from "tight-looping" on retrying their request.
|
||||
|
||||
Raises:
|
||||
LimitExceededError: If an action could not be performed, along with the time in
|
||||
@@ -316,9 +319,8 @@ class Ratelimiter:
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
# We pause for a bit here to stop clients from "tight-looping" on
|
||||
# retrying their request.
|
||||
await self.clock.sleep(0.5)
|
||||
if pause:
|
||||
await self.clock.sleep(pause)
|
||||
|
||||
raise LimitExceededError(
|
||||
limiter_name=self._limiter_name,
|
||||
|
||||
@@ -221,9 +221,13 @@ class Config:
|
||||
The number of milliseconds in the duration.
|
||||
|
||||
Raises:
|
||||
TypeError, if given something other than an integer or a string
|
||||
TypeError: if given something other than an integer or a string, or the
|
||||
duration is using an incorrect suffix.
|
||||
ValueError: if given a string not of the form described above.
|
||||
"""
|
||||
# For integers, we prefer to use `type(value) is int` instead of
|
||||
# `isinstance(value, int)` because we want to exclude subclasses of int, such as
|
||||
# bool.
|
||||
if type(value) is int: # noqa: E721
|
||||
return value
|
||||
elif isinstance(value, str):
|
||||
@@ -246,9 +250,20 @@ class Config:
|
||||
if suffix in sizes:
|
||||
value = value[:-1]
|
||||
size = sizes[suffix]
|
||||
elif suffix.isdigit():
|
||||
# No suffix is treated as milliseconds.
|
||||
value = value
|
||||
size = 1
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Bad duration suffix {value} (expected no suffix or one of these suffixes: {sizes.keys()})"
|
||||
)
|
||||
|
||||
return int(value) * size
|
||||
else:
|
||||
raise TypeError(f"Bad duration {value!r}")
|
||||
raise TypeError(
|
||||
f"Bad duration type {value!r} (expected int or string duration)"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def abspath(file_path: str) -> str:
|
||||
|
||||
@@ -436,6 +436,9 @@ class ExperimentalConfig(Config):
|
||||
("experimental", "msc4108_delegation_endpoint"),
|
||||
)
|
||||
|
||||
# MSC4133: Custom profile fields
|
||||
self.msc4133_enabled: bool = experimental.get("msc4133_enabled", False)
|
||||
|
||||
# MSC4210: Remove legacy mentions
|
||||
self.msc4210_enabled: bool = experimental.get("msc4210_enabled", False)
|
||||
|
||||
|
||||
@@ -228,3 +228,9 @@ class RatelimitConfig(Config):
|
||||
config.get("remote_media_download_burst_count", "500M")
|
||||
),
|
||||
)
|
||||
|
||||
self.rc_presence_per_user = RatelimitSettings.parse(
|
||||
config,
|
||||
"rc_presence.per_user",
|
||||
defaults={"per_second": 0.1, "burst_count": 1},
|
||||
)
|
||||
|
||||
@@ -566,6 +566,7 @@ def _is_membership_change_allowed(
|
||||
logger.debug(
|
||||
"_is_membership_change_allowed: %s",
|
||||
{
|
||||
"caller_membership": caller.membership if caller else None,
|
||||
"caller_in_room": caller_in_room,
|
||||
"caller_invited": caller_invited,
|
||||
"caller_knocked": caller_knocked,
|
||||
@@ -677,7 +678,8 @@ def _is_membership_change_allowed(
|
||||
and join_rule == JoinRules.KNOCK_RESTRICTED
|
||||
)
|
||||
):
|
||||
if not caller_in_room and not caller_invited:
|
||||
# You can only join the room if you are invited or are already in the room.
|
||||
if not (caller_in_room or caller_invited):
|
||||
raise AuthError(403, "You are not invited to this room.")
|
||||
else:
|
||||
# TODO (erikj): may_join list
|
||||
|
||||
@@ -42,7 +42,7 @@ import attr
|
||||
from typing_extensions import Literal
|
||||
from unpaddedbase64 import encode_base64
|
||||
|
||||
from synapse.api.constants import RelationTypes
|
||||
from synapse.api.constants import EventTypes, RelationTypes
|
||||
from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
|
||||
from synapse.synapse_rust.events import EventInternalMetadata
|
||||
from synapse.types import JsonDict, StrCollection
|
||||
@@ -325,12 +325,17 @@ class EventBase(metaclass=abc.ABCMeta):
|
||||
def __repr__(self) -> str:
|
||||
rejection = f"REJECTED={self.rejected_reason}, " if self.rejected_reason else ""
|
||||
|
||||
conditional_membership_string = ""
|
||||
if self.get("type") == EventTypes.Member:
|
||||
conditional_membership_string = f"membership={self.membership}, "
|
||||
|
||||
return (
|
||||
f"<{self.__class__.__name__} "
|
||||
f"{rejection}"
|
||||
f"event_id={self.event_id}, "
|
||||
f"type={self.get('type')}, "
|
||||
f"state_key={self.get('state_key')}, "
|
||||
f"{conditional_membership_string}"
|
||||
f"outlier={self.internal_metadata.is_outlier()}"
|
||||
">"
|
||||
)
|
||||
|
||||
@@ -66,50 +66,67 @@ class InviteAutoAccepter:
|
||||
event: The incoming event.
|
||||
"""
|
||||
# Check if the event is an invite for a local user.
|
||||
is_invite_for_local_user = (
|
||||
event.type == EventTypes.Member
|
||||
and event.is_state()
|
||||
and event.membership == Membership.INVITE
|
||||
and self._api.is_mine(event.state_key)
|
||||
)
|
||||
if (
|
||||
event.type != EventTypes.Member
|
||||
or event.is_state() is False
|
||||
or event.membership != Membership.INVITE
|
||||
or self._api.is_mine(event.state_key) is False
|
||||
):
|
||||
return
|
||||
|
||||
# Only accept invites for direct messages if the configuration mandates it.
|
||||
is_direct_message = event.content.get("is_direct", False)
|
||||
is_allowed_by_direct_message_rules = (
|
||||
not self._config.accept_invites_only_for_direct_messages
|
||||
or is_direct_message is True
|
||||
)
|
||||
if (
|
||||
self._config.accept_invites_only_for_direct_messages
|
||||
and is_direct_message is False
|
||||
):
|
||||
return
|
||||
|
||||
# Only accept invites from remote users if the configuration mandates it.
|
||||
is_from_local_user = self._api.is_mine(event.sender)
|
||||
is_allowed_by_local_user_rules = (
|
||||
not self._config.accept_invites_only_from_local_users
|
||||
or is_from_local_user is True
|
||||
if (
|
||||
self._config.accept_invites_only_from_local_users
|
||||
and is_from_local_user is False
|
||||
):
|
||||
return
|
||||
|
||||
# Check the user is activated.
|
||||
recipient = await self._api.get_userinfo_by_id(event.state_key)
|
||||
|
||||
# Ignore if the user doesn't exist.
|
||||
if recipient is None:
|
||||
return
|
||||
|
||||
# Never accept invites for deactivated users.
|
||||
if recipient.is_deactivated:
|
||||
return
|
||||
|
||||
# Never accept invites for suspended users.
|
||||
if recipient.suspended:
|
||||
return
|
||||
|
||||
# Never accept invites for locked users.
|
||||
if recipient.locked:
|
||||
return
|
||||
|
||||
# Make the user join the room. We run this as a background process to circumvent a race condition
|
||||
# that occurs when responding to invites over federation (see https://github.com/matrix-org/synapse-auto-accept-invite/issues/12)
|
||||
run_as_background_process(
|
||||
"retry_make_join",
|
||||
self._retry_make_join,
|
||||
event.state_key,
|
||||
event.state_key,
|
||||
event.room_id,
|
||||
"join",
|
||||
bg_start_span=False,
|
||||
)
|
||||
|
||||
if (
|
||||
is_invite_for_local_user
|
||||
and is_allowed_by_direct_message_rules
|
||||
and is_allowed_by_local_user_rules
|
||||
):
|
||||
# Make the user join the room. We run this as a background process to circumvent a race condition
|
||||
# that occurs when responding to invites over federation (see https://github.com/matrix-org/synapse-auto-accept-invite/issues/12)
|
||||
run_as_background_process(
|
||||
"retry_make_join",
|
||||
self._retry_make_join,
|
||||
event.state_key,
|
||||
event.state_key,
|
||||
event.room_id,
|
||||
"join",
|
||||
bg_start_span=False,
|
||||
if is_direct_message:
|
||||
# Mark this room as a direct message!
|
||||
await self._mark_room_as_direct_message(
|
||||
event.state_key, event.sender, event.room_id
|
||||
)
|
||||
|
||||
if is_direct_message:
|
||||
# Mark this room as a direct message!
|
||||
await self._mark_room_as_direct_message(
|
||||
event.state_key, event.sender, event.room_id
|
||||
)
|
||||
|
||||
async def _mark_room_as_direct_message(
|
||||
self, user_id: str, dm_user_id: str, room_id: str
|
||||
) -> None:
|
||||
|
||||
@@ -24,7 +24,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
import attr
|
||||
from signedjson.types import SigningKey
|
||||
|
||||
from synapse.api.constants import MAX_DEPTH
|
||||
from synapse.api.constants import MAX_DEPTH, EventTypes
|
||||
from synapse.api.room_versions import (
|
||||
KNOWN_EVENT_FORMAT_VERSIONS,
|
||||
EventFormatVersions,
|
||||
@@ -109,6 +109,19 @@ class EventBuilder:
|
||||
def is_state(self) -> bool:
|
||||
return self._state_key is not None
|
||||
|
||||
def is_mine_id(self, user_id: str) -> bool:
|
||||
"""Determines whether a user ID or room alias originates from this homeserver.
|
||||
|
||||
Returns:
|
||||
`True` if the hostname part of the user ID or room alias matches this
|
||||
homeserver.
|
||||
`False` otherwise, or if the user ID or room alias is malformed.
|
||||
"""
|
||||
localpart_hostname = user_id.split(":", 1)
|
||||
if len(localpart_hostname) < 2:
|
||||
return False
|
||||
return localpart_hostname[1] == self._hostname
|
||||
|
||||
async def build(
|
||||
self,
|
||||
prev_event_ids: List[str],
|
||||
@@ -142,6 +155,46 @@ class EventBuilder:
|
||||
self, state_ids
|
||||
)
|
||||
|
||||
# Check for out-of-band membership that may have been exposed on `/sync` but
|
||||
# the events have not been de-outliered yet so they won't be part of the
|
||||
# room state yet.
|
||||
#
|
||||
# This helps in situations where a remote homeserver invites a local user to
|
||||
# a room that we're already participating in; and we've persisted the invite
|
||||
# as an out-of-band membership (outlier), but it hasn't been pushed to us as
|
||||
# part of a `/send` transaction yet and de-outliered. This also helps for
|
||||
# any of the other out-of-band membership transitions.
|
||||
#
|
||||
# As an optimization, we could check if the room state already includes a
|
||||
# non-`leave` membership event, then we can assume the membership event has
|
||||
# been de-outliered and we don't need to check for an out-of-band
|
||||
# membership. But we don't have the necessary information from a
|
||||
# `StateMap[str]` and we'll just have to take the hit of this extra lookup
|
||||
# for any membership event for now.
|
||||
if self.type == EventTypes.Member and self.is_mine_id(self.state_key):
|
||||
(
|
||||
_membership,
|
||||
member_event_id,
|
||||
) = await self._store.get_local_current_membership_for_user_in_room(
|
||||
user_id=self.state_key,
|
||||
room_id=self.room_id,
|
||||
)
|
||||
# There is no need to check if the membership is actually an
|
||||
# out-of-band membership (`outlier`) as we would end up with the
|
||||
# same result either way (adding the member event to the
|
||||
# `auth_event_ids`).
|
||||
if (
|
||||
member_event_id is not None
|
||||
# We only need to be careful about duplicating the event in the
|
||||
# `auth_event_ids` list (duplicate `type`/`state_key` is part of the
|
||||
# authorization rules)
|
||||
and member_event_id not in auth_event_ids
|
||||
):
|
||||
auth_event_ids.append(member_event_id)
|
||||
# Also make sure to point to the previous membership event that will
|
||||
# allow this one to happen so the computed state works out.
|
||||
prev_event_ids.append(member_event_id)
|
||||
|
||||
format_version = self.room_version.event_format
|
||||
# The types of auth/prev events changes between event versions.
|
||||
prev_events: Union[StrCollection, List[Tuple[str, Dict[str, str]]]]
|
||||
|
||||
@@ -2272,8 +2272,9 @@ class FederationEventHandler:
|
||||
event_and_contexts, backfilled=backfilled
|
||||
)
|
||||
|
||||
# After persistence we always need to notify replication there may
|
||||
# be new data.
|
||||
# After persistence, we never notify clients (wake up `/sync` streams) about
|
||||
# backfilled events but it's important to let all the workers know about any
|
||||
# new event (backfilled or not) because TODO
|
||||
self._notifier.notify_replication()
|
||||
|
||||
if self._ephemeral_messages_enabled:
|
||||
|
||||
@@ -1002,7 +1002,21 @@ class OidcProvider:
|
||||
"""
|
||||
|
||||
state = generate_token()
|
||||
nonce = generate_token()
|
||||
|
||||
# Generate a nonce 32 characters long. When encoded with base64url later on,
|
||||
# the nonce will be 43 characters when sent to the identity provider.
|
||||
#
|
||||
# While RFC7636 does not specify a minimum length for the `nonce`
|
||||
# parameter, the TI-Messenger IDP_FD spec v1.7.3 does require it to be
|
||||
# between 43 and 128 characters. This spec concerns using Matrix for
|
||||
# communication in German healthcare.
|
||||
#
|
||||
# As increasing the length only strengthens security, we use this length
|
||||
# to allow TI-Messenger deployments using Synapse to satisfy this
|
||||
# external spec.
|
||||
#
|
||||
# See https://github.com/element-hq/synapse/pull/18109 for more context.
|
||||
nonce = generate_token(length=32)
|
||||
code_verifier = ""
|
||||
|
||||
if not client_redirect_url:
|
||||
|
||||
@@ -32,7 +32,7 @@ from synapse.api.errors import (
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
|
||||
from synapse.types import JsonDict, Requester, UserID, create_requester
|
||||
from synapse.types import JsonDict, JsonValue, Requester, UserID, create_requester
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.stringutils import parse_and_validate_mxc_uri
|
||||
|
||||
@@ -43,6 +43,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_DISPLAYNAME_LEN = 256
|
||||
MAX_AVATAR_URL_LEN = 1000
|
||||
# Field name length is specced at 255 bytes.
|
||||
MAX_CUSTOM_FIELD_LEN = 255
|
||||
|
||||
|
||||
class ProfileHandler:
|
||||
@@ -90,7 +92,15 @@ class ProfileHandler:
|
||||
|
||||
if self.hs.is_mine(target_user):
|
||||
profileinfo = await self.store.get_profileinfo(target_user)
|
||||
if profileinfo.display_name is None and profileinfo.avatar_url is None:
|
||||
extra_fields = {}
|
||||
if self.hs.config.experimental.msc4133_enabled:
|
||||
extra_fields = await self.store.get_profile_fields(target_user)
|
||||
|
||||
if (
|
||||
profileinfo.display_name is None
|
||||
and profileinfo.avatar_url is None
|
||||
and not extra_fields
|
||||
):
|
||||
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
|
||||
|
||||
# Do not include display name or avatar if unset.
|
||||
@@ -99,6 +109,9 @@ class ProfileHandler:
|
||||
ret[ProfileFields.DISPLAYNAME] = profileinfo.display_name
|
||||
if profileinfo.avatar_url is not None:
|
||||
ret[ProfileFields.AVATAR_URL] = profileinfo.avatar_url
|
||||
if extra_fields:
|
||||
ret.update(extra_fields)
|
||||
|
||||
return ret
|
||||
else:
|
||||
try:
|
||||
@@ -403,6 +416,110 @@ class ProfileHandler:
|
||||
|
||||
return True
|
||||
|
||||
async def get_profile_field(
|
||||
self, target_user: UserID, field_name: str
|
||||
) -> JsonValue:
|
||||
"""
|
||||
Fetch a user's profile from the database for local users and over federation
|
||||
for remote users.
|
||||
|
||||
Args:
|
||||
target_user: The user ID to fetch the profile for.
|
||||
field_name: The field to fetch the profile for.
|
||||
|
||||
Returns:
|
||||
The value for the profile field or None if the field does not exist.
|
||||
"""
|
||||
if self.hs.is_mine(target_user):
|
||||
try:
|
||||
field_value = await self.store.get_profile_field(
|
||||
target_user, field_name
|
||||
)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
|
||||
raise
|
||||
|
||||
return field_value
|
||||
else:
|
||||
try:
|
||||
result = await self.federation.make_query(
|
||||
destination=target_user.domain,
|
||||
query_type="profile",
|
||||
args={"user_id": target_user.to_string(), "field": field_name},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
except RequestSendFailed as e:
|
||||
raise SynapseError(502, "Failed to fetch profile") from e
|
||||
except HttpResponseException as e:
|
||||
raise e.to_synapse_error()
|
||||
|
||||
return result.get(field_name)
|
||||
|
||||
async def set_profile_field(
|
||||
self,
|
||||
target_user: UserID,
|
||||
requester: Requester,
|
||||
field_name: str,
|
||||
new_value: JsonValue,
|
||||
by_admin: bool = False,
|
||||
deactivation: bool = False,
|
||||
) -> None:
|
||||
"""Set a new profile field for a user.
|
||||
|
||||
Args:
|
||||
target_user: the user whose profile is to be changed.
|
||||
requester: The user attempting to make this change.
|
||||
field_name: The name of the profile field to update.
|
||||
new_value: The new field value for this user.
|
||||
by_admin: Whether this change was made by an administrator.
|
||||
deactivation: Whether this change was made while deactivating the user.
|
||||
"""
|
||||
if not self.hs.is_mine(target_user):
|
||||
raise SynapseError(400, "User is not hosted on this homeserver")
|
||||
|
||||
if not by_admin and target_user != requester.user:
|
||||
raise AuthError(403, "Cannot set another user's profile")
|
||||
|
||||
await self.store.set_profile_field(target_user, field_name, new_value)
|
||||
|
||||
# Custom fields do not propagate into the user directory *or* rooms.
|
||||
profile = await self.store.get_profileinfo(target_user)
|
||||
await self._third_party_rules.on_profile_update(
|
||||
target_user.to_string(), profile, by_admin, deactivation
|
||||
)
|
||||
|
||||
async def delete_profile_field(
|
||||
self,
|
||||
target_user: UserID,
|
||||
requester: Requester,
|
||||
field_name: str,
|
||||
by_admin: bool = False,
|
||||
deactivation: bool = False,
|
||||
) -> None:
|
||||
"""Delete a field from a user's profile.
|
||||
|
||||
Args:
|
||||
target_user: the user whose profile is to be changed.
|
||||
requester: The user attempting to make this change.
|
||||
field_name: The name of the profile field to remove.
|
||||
by_admin: Whether this change was made by an administrator.
|
||||
deactivation: Whether this change was made while deactivating the user.
|
||||
"""
|
||||
if not self.hs.is_mine(target_user):
|
||||
raise SynapseError(400, "User is not hosted on this homeserver")
|
||||
|
||||
if not by_admin and target_user != requester.user:
|
||||
raise AuthError(400, "Cannot set another user's profile")
|
||||
|
||||
await self.store.delete_profile_field(target_user, field_name)
|
||||
|
||||
# Custom fields do not propagate into the user directory *or* rooms.
|
||||
profile = await self.store.get_profileinfo(target_user)
|
||||
await self._third_party_rules.on_profile_update(
|
||||
target_user.to_string(), profile, by_admin, deactivation
|
||||
)
|
||||
|
||||
async def on_profile_query(self, args: JsonDict) -> JsonDict:
|
||||
"""Handles federation profile query requests."""
|
||||
|
||||
@@ -419,13 +536,24 @@ class ProfileHandler:
|
||||
|
||||
just_field = args.get("field", None)
|
||||
|
||||
response = {}
|
||||
response: JsonDict = {}
|
||||
try:
|
||||
if just_field is None or just_field == "displayname":
|
||||
if just_field is None or just_field == ProfileFields.DISPLAYNAME:
|
||||
response["displayname"] = await self.store.get_profile_displayname(user)
|
||||
|
||||
if just_field is None or just_field == "avatar_url":
|
||||
if just_field is None or just_field == ProfileFields.AVATAR_URL:
|
||||
response["avatar_url"] = await self.store.get_profile_avatar_url(user)
|
||||
|
||||
if self.hs.config.experimental.msc4133_enabled:
|
||||
if just_field is None:
|
||||
response.update(await self.store.get_profile_fields(user))
|
||||
elif just_field not in (
|
||||
ProfileFields.DISPLAYNAME,
|
||||
ProfileFields.AVATAR_URL,
|
||||
):
|
||||
response[just_field] = await self.store.get_profile_field(
|
||||
user, just_field
|
||||
)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#
|
||||
#
|
||||
|
||||
import logging
|
||||
import random
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
@@ -269,6 +270,10 @@ class WaitingLock:
|
||||
def _get_next_retry_interval(self) -> float:
|
||||
next = self._retry_interval
|
||||
self._retry_interval = max(5, next * 2)
|
||||
if self._retry_interval > 5 * 2 ^ 7: # ~10 minutes
|
||||
logging.warning(
|
||||
f"Lock timeout is getting excessive: {self._retry_interval}s. There may be a deadlock."
|
||||
)
|
||||
return next * random.uniform(0.9, 1.1)
|
||||
|
||||
|
||||
@@ -344,4 +349,8 @@ class WaitingMultiLock:
|
||||
def _get_next_retry_interval(self) -> float:
|
||||
next = self._retry_interval
|
||||
self._retry_interval = max(5, next * 2)
|
||||
if self._retry_interval > 5 * 2 ^ 7: # ~10 minutes
|
||||
logging.warning(
|
||||
f"Lock timeout is getting excessive: {self._retry_interval}s. There may be a deadlock."
|
||||
)
|
||||
return next * random.uniform(0.9, 1.1)
|
||||
|
||||
@@ -29,7 +29,7 @@ from synapse.rest.client import (
|
||||
account_validity,
|
||||
appservice_ping,
|
||||
auth,
|
||||
auth_issuer,
|
||||
auth_metadata,
|
||||
capabilities,
|
||||
delayed_events,
|
||||
devices,
|
||||
@@ -121,7 +121,7 @@ CLIENT_SERVLET_FUNCTIONS: Tuple[RegisterServletsFunc, ...] = (
|
||||
mutual_rooms.register_servlets,
|
||||
login_token_request.register_servlets,
|
||||
rendezvous.register_servlets,
|
||||
auth_issuer.register_servlets,
|
||||
auth_metadata.register_servlets,
|
||||
)
|
||||
|
||||
SERVLET_GROUPS: Dict[str, Iterable[RegisterServletsFunc]] = {
|
||||
@@ -187,7 +187,7 @@ class ClientRestResource(JsonResource):
|
||||
mutual_rooms.register_servlets,
|
||||
login_token_request.register_servlets,
|
||||
rendezvous.register_servlets,
|
||||
auth_issuer.register_servlets,
|
||||
auth_metadata.register_servlets,
|
||||
]:
|
||||
continue
|
||||
|
||||
|
||||
@@ -32,6 +32,8 @@ logger = logging.getLogger(__name__)
|
||||
class AuthIssuerServlet(RestServlet):
|
||||
"""
|
||||
Advertises what OpenID Connect issuer clients should use to authorise users.
|
||||
This endpoint was defined in a previous iteration of MSC2965, and is still
|
||||
used by some clients.
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns(
|
||||
@@ -63,7 +65,42 @@ class AuthIssuerServlet(RestServlet):
|
||||
)
|
||||
|
||||
|
||||
class AuthMetadataServlet(RestServlet):
|
||||
"""
|
||||
Advertises the OAuth 2.0 server metadata for the homeserver.
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns(
|
||||
"/org.matrix.msc2965/auth_metadata$",
|
||||
unstable=True,
|
||||
releases=(),
|
||||
)
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self._config = hs.config
|
||||
self._auth = hs.get_auth()
|
||||
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
if self._config.experimental.msc3861.enabled:
|
||||
# If MSC3861 is enabled, we can assume self._auth is an instance of MSC3861DelegatedAuth
|
||||
# We import lazily here because of the authlib requirement
|
||||
from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth
|
||||
|
||||
auth = cast(MSC3861DelegatedAuth, self._auth)
|
||||
return 200, await auth.auth_metadata()
|
||||
else:
|
||||
# Wouldn't expect this to be reached: the servlet shouldn't have been
|
||||
# registered. Still, fail gracefully if we are registered for some reason.
|
||||
raise SynapseError(
|
||||
404,
|
||||
"OIDC discovery has not been configured on this homeserver",
|
||||
Codes.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
# We use the MSC3861 values as they are used by multiple MSCs
|
||||
if hs.config.experimental.msc3861.enabled:
|
||||
AuthIssuerServlet(hs).register(http_server)
|
||||
AuthMetadataServlet(hs).register(http_server)
|
||||
@@ -92,6 +92,23 @@ class CapabilitiesRestServlet(RestServlet):
|
||||
"enabled": self.config.experimental.msc3664_enabled,
|
||||
}
|
||||
|
||||
if self.config.experimental.msc4133_enabled:
|
||||
response["capabilities"]["uk.tcpip.msc4133.profile_fields"] = {
|
||||
"enabled": True,
|
||||
}
|
||||
|
||||
# Ensure this is consistent with the legacy m.set_displayname and
|
||||
# m.set_avatar_url.
|
||||
disallowed = []
|
||||
if not self.config.registration.enable_set_displayname:
|
||||
disallowed.append("displayname")
|
||||
if not self.config.registration.enable_set_avatar_url:
|
||||
disallowed.append("avatar_url")
|
||||
if disallowed:
|
||||
response["capabilities"]["uk.tcpip.msc4133.profile_fields"][
|
||||
"disallowed"
|
||||
] = disallowed
|
||||
|
||||
return HTTPStatus.OK, response
|
||||
|
||||
|
||||
|
||||
@@ -24,7 +24,8 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.errors import AuthError, SynapseError
|
||||
from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.handlers.presence import format_user_presence_state
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
@@ -48,6 +49,14 @@ class PresenceStatusRestServlet(RestServlet):
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
self.clock = hs.get_clock()
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastores().main
|
||||
|
||||
# Ratelimiter for presence updates, keyed by requester.
|
||||
self._presence_per_user_limiter = Ratelimiter(
|
||||
store=self.store,
|
||||
clock=self.clock,
|
||||
cfg=hs.config.ratelimiting.rc_presence_per_user,
|
||||
)
|
||||
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
@@ -82,6 +91,17 @@ class PresenceStatusRestServlet(RestServlet):
|
||||
if requester.user != user:
|
||||
raise AuthError(403, "Can only set your own presence state")
|
||||
|
||||
# ignore the presence update if the ratelimit is exceeded
|
||||
try:
|
||||
await self._presence_per_user_limiter.ratelimit(requester)
|
||||
except LimitExceededError as e:
|
||||
logger.debug("User presence ratelimit exceeded; ignoring it.")
|
||||
return 429, {
|
||||
"errcode": Codes.LIMIT_EXCEEDED,
|
||||
"error": "Too many requests",
|
||||
"retry_after_ms": e.retry_after_ms,
|
||||
}
|
||||
|
||||
state = {}
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
@@ -21,10 +21,13 @@
|
||||
|
||||
"""This module contains REST servlets to do with profile: /profile/<paths>"""
|
||||
|
||||
import re
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.constants import ProfileFields
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.handlers.profile import MAX_CUSTOM_FIELD_LEN
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
@@ -33,7 +36,8 @@ from synapse.http.servlet import (
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client._base import client_patterns
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.types import JsonDict, JsonValue, UserID
|
||||
from synapse.util.stringutils import is_namedspaced_grammar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
@@ -91,6 +95,11 @@ class ProfileDisplaynameRestServlet(RestServlet):
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
if not UserID.is_valid(user_id):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
|
||||
)
|
||||
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
user = UserID.from_string(user_id)
|
||||
is_admin = await self.auth.is_server_admin(requester)
|
||||
@@ -101,9 +110,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
|
||||
new_name = content["displayname"]
|
||||
except Exception:
|
||||
raise SynapseError(
|
||||
code=400,
|
||||
msg="Unable to parse name",
|
||||
errcode=Codes.BAD_JSON,
|
||||
400, "Missing key 'displayname'", errcode=Codes.MISSING_PARAM
|
||||
)
|
||||
|
||||
propagate = _read_propagate(self.hs, request)
|
||||
@@ -166,6 +173,11 @@ class ProfileAvatarURLRestServlet(RestServlet):
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
if not UserID.is_valid(user_id):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
|
||||
)
|
||||
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
user = UserID.from_string(user_id)
|
||||
is_admin = await self.auth.is_server_admin(requester)
|
||||
@@ -232,7 +244,180 @@ class ProfileRestServlet(RestServlet):
|
||||
return 200, ret
|
||||
|
||||
|
||||
class UnstableProfileFieldRestServlet(RestServlet):
|
||||
PATTERNS = [
|
||||
re.compile(
|
||||
r"^/_matrix/client/unstable/uk\.tcpip\.msc4133/profile/(?P<user_id>[^/]*)/(?P<field_name>[^/]*)"
|
||||
)
|
||||
]
|
||||
CATEGORY = "Event sending requests"
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, user_id: str, field_name: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester_user = None
|
||||
|
||||
if self.hs.config.server.require_auth_for_profile_requests:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user = requester.user
|
||||
|
||||
if not UserID.is_valid(user_id):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
|
||||
)
|
||||
|
||||
if not field_name:
|
||||
raise SynapseError(400, "Field name too short", errcode=Codes.INVALID_PARAM)
|
||||
|
||||
if len(field_name.encode("utf-8")) > MAX_CUSTOM_FIELD_LEN:
|
||||
raise SynapseError(400, "Field name too long", errcode=Codes.KEY_TOO_LARGE)
|
||||
if not is_namedspaced_grammar(field_name):
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Field name does not follow Common Namespaced Identifier Grammar",
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
user = UserID.from_string(user_id)
|
||||
await self.profile_handler.check_profile_query_allowed(user, requester_user)
|
||||
|
||||
if field_name == ProfileFields.DISPLAYNAME:
|
||||
field_value: JsonValue = await self.profile_handler.get_displayname(user)
|
||||
elif field_name == ProfileFields.AVATAR_URL:
|
||||
field_value = await self.profile_handler.get_avatar_url(user)
|
||||
else:
|
||||
field_value = await self.profile_handler.get_profile_field(user, field_name)
|
||||
|
||||
return 200, {field_name: field_value}
|
||||
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, user_id: str, field_name: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
if not UserID.is_valid(user_id):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
|
||||
)
|
||||
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
user = UserID.from_string(user_id)
|
||||
is_admin = await self.auth.is_server_admin(requester)
|
||||
|
||||
if not field_name:
|
||||
raise SynapseError(400, "Field name too short", errcode=Codes.INVALID_PARAM)
|
||||
|
||||
if len(field_name.encode("utf-8")) > MAX_CUSTOM_FIELD_LEN:
|
||||
raise SynapseError(400, "Field name too long", errcode=Codes.KEY_TOO_LARGE)
|
||||
if not is_namedspaced_grammar(field_name):
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Field name does not follow Common Namespaced Identifier Grammar",
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
try:
|
||||
new_value = content[field_name]
|
||||
except KeyError:
|
||||
raise SynapseError(
|
||||
400, f"Missing key '{field_name}'", errcode=Codes.MISSING_PARAM
|
||||
)
|
||||
|
||||
propagate = _read_propagate(self.hs, request)
|
||||
|
||||
requester_suspended = (
|
||||
await self.hs.get_datastores().main.get_user_suspended_status(
|
||||
requester.user.to_string()
|
||||
)
|
||||
)
|
||||
|
||||
if requester_suspended:
|
||||
raise SynapseError(
|
||||
403,
|
||||
"Updating profile while account is suspended is not allowed.",
|
||||
Codes.USER_ACCOUNT_SUSPENDED,
|
||||
)
|
||||
|
||||
if field_name == ProfileFields.DISPLAYNAME:
|
||||
await self.profile_handler.set_displayname(
|
||||
user, requester, new_value, is_admin, propagate=propagate
|
||||
)
|
||||
elif field_name == ProfileFields.AVATAR_URL:
|
||||
await self.profile_handler.set_avatar_url(
|
||||
user, requester, new_value, is_admin, propagate=propagate
|
||||
)
|
||||
else:
|
||||
await self.profile_handler.set_profile_field(
|
||||
user, requester, field_name, new_value, is_admin
|
||||
)
|
||||
|
||||
return 200, {}
|
||||
|
||||
async def on_DELETE(
|
||||
self, request: SynapseRequest, user_id: str, field_name: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
if not UserID.is_valid(user_id):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
|
||||
)
|
||||
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
user = UserID.from_string(user_id)
|
||||
is_admin = await self.auth.is_server_admin(requester)
|
||||
|
||||
if not field_name:
|
||||
raise SynapseError(400, "Field name too short", errcode=Codes.INVALID_PARAM)
|
||||
|
||||
if len(field_name.encode("utf-8")) > MAX_CUSTOM_FIELD_LEN:
|
||||
raise SynapseError(400, "Field name too long", errcode=Codes.KEY_TOO_LARGE)
|
||||
if not is_namedspaced_grammar(field_name):
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Field name does not follow Common Namespaced Identifier Grammar",
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
propagate = _read_propagate(self.hs, request)
|
||||
|
||||
requester_suspended = (
|
||||
await self.hs.get_datastores().main.get_user_suspended_status(
|
||||
requester.user.to_string()
|
||||
)
|
||||
)
|
||||
|
||||
if requester_suspended:
|
||||
raise SynapseError(
|
||||
403,
|
||||
"Updating profile while account is suspended is not allowed.",
|
||||
Codes.USER_ACCOUNT_SUSPENDED,
|
||||
)
|
||||
|
||||
if field_name == ProfileFields.DISPLAYNAME:
|
||||
await self.profile_handler.set_displayname(
|
||||
user, requester, "", is_admin, propagate=propagate
|
||||
)
|
||||
elif field_name == ProfileFields.AVATAR_URL:
|
||||
await self.profile_handler.set_avatar_url(
|
||||
user, requester, "", is_admin, propagate=propagate
|
||||
)
|
||||
else:
|
||||
await self.profile_handler.delete_profile_field(
|
||||
user, requester, field_name, is_admin
|
||||
)
|
||||
|
||||
return 200, {}
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
# The specific displayname / avatar URL / custom field endpoints *must* appear
|
||||
# before their corresponding generic profile endpoint.
|
||||
ProfileDisplaynameRestServlet(hs).register(http_server)
|
||||
ProfileAvatarURLRestServlet(hs).register(http_server)
|
||||
ProfileRestServlet(hs).register(http_server)
|
||||
if hs.config.experimental.msc4133_enabled:
|
||||
UnstableProfileFieldRestServlet(hs).register(http_server)
|
||||
|
||||
@@ -24,9 +24,10 @@ from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union
|
||||
|
||||
from synapse.api.constants import AccountDataTypes, EduTypes, Membership, PresenceState
|
||||
from synapse.api.errors import Codes, StoreError, SynapseError
|
||||
from synapse.api.errors import Codes, LimitExceededError, StoreError, SynapseError
|
||||
from synapse.api.filtering import FilterCollection
|
||||
from synapse.api.presence import UserPresenceState
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.events.utils import (
|
||||
SerializeEventConfig,
|
||||
format_event_for_client_v2_without_room_id,
|
||||
@@ -126,6 +127,13 @@ class SyncRestServlet(RestServlet):
|
||||
cache_name="sync_valid_filter",
|
||||
)
|
||||
|
||||
# Ratelimiter for presence updates, keyed by requester.
|
||||
self._presence_per_user_limiter = Ratelimiter(
|
||||
store=self.store,
|
||||
clock=self.clock,
|
||||
cfg=hs.config.ratelimiting.rc_presence_per_user,
|
||||
)
|
||||
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
# This will always be set by the time Twisted calls us.
|
||||
assert request.args is not None
|
||||
@@ -239,7 +247,14 @@ class SyncRestServlet(RestServlet):
|
||||
# send any outstanding server notices to the user.
|
||||
await self._server_notices_sender.on_user_syncing(user.to_string())
|
||||
|
||||
affect_presence = set_presence != PresenceState.OFFLINE
|
||||
# ignore the presence update if the ratelimit is exceeded but do not pause the request
|
||||
try:
|
||||
await self._presence_per_user_limiter.ratelimit(requester, pause=0.0)
|
||||
except LimitExceededError:
|
||||
affect_presence = False
|
||||
logger.debug("User set_presence ratelimit exceeded; ignoring it.")
|
||||
else:
|
||||
affect_presence = set_presence != PresenceState.OFFLINE
|
||||
|
||||
context = await self.presence_handler.user_syncing(
|
||||
user.to_string(),
|
||||
|
||||
@@ -172,6 +172,8 @@ class VersionsRestServlet(RestServlet):
|
||||
"org.matrix.msc4140": bool(self.config.server.max_event_delay_ms),
|
||||
# Simplified sliding sync
|
||||
"org.matrix.simplified_msc3575": msc3575_enabled,
|
||||
# Arbitrary key-value profile fields.
|
||||
"uk.tcpip.msc4133": self.config.experimental.msc4133_enabled,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
@@ -391,7 +391,7 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||
def is_mine(self, domain_specific_string: DomainSpecificString) -> bool:
|
||||
return domain_specific_string.domain == self.hostname
|
||||
|
||||
def is_mine_id(self, string: str) -> bool:
|
||||
def is_mine_id(self, user_id: str) -> bool:
|
||||
"""Determines whether a user ID or room alias originates from this homeserver.
|
||||
|
||||
Returns:
|
||||
@@ -399,7 +399,7 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||
homeserver.
|
||||
`False` otherwise, or if the user ID or room alias is malformed.
|
||||
"""
|
||||
localpart_hostname = string.split(":", 1)
|
||||
localpart_hostname = user_id.split(":", 1)
|
||||
if len(localpart_hostname) < 2:
|
||||
return False
|
||||
return localpart_hostname[1] == self.hostname
|
||||
|
||||
@@ -789,7 +789,7 @@ class BackgroundUpdater:
|
||||
# we may already have a half-built index. Let's just drop it
|
||||
# before trying to create it again.
|
||||
|
||||
sql = "DROP INDEX IF EXISTS %s" % (index_name,)
|
||||
sql = "DROP INDEX CONCURRENTLY IF EXISTS %s" % (index_name,)
|
||||
logger.debug("[SQL] %s", sql)
|
||||
c.execute(sql)
|
||||
|
||||
@@ -814,7 +814,7 @@ class BackgroundUpdater:
|
||||
|
||||
if replaces_index is not None:
|
||||
# We drop the old index as the new index has now been created.
|
||||
sql = f"DROP INDEX IF EXISTS {replaces_index}"
|
||||
sql = f"DROP INDEX CONCURRENTLY IF EXISTS {replaces_index}"
|
||||
logger.debug("[SQL] %s", sql)
|
||||
c.execute(sql)
|
||||
finally:
|
||||
|
||||
@@ -18,8 +18,13 @@
|
||||
# [This file includes modifications made by New Vector Limited]
|
||||
#
|
||||
#
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, cast
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from synapse.api.constants import ProfileFields
|
||||
from synapse.api.errors import Codes, StoreError
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
@@ -27,13 +32,17 @@ from synapse.storage.database import (
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.databases.main.roommember import ProfileInfo
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||
from synapse.types import JsonDict, JsonValue, UserID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
# The number of bytes that the serialized profile can have.
|
||||
MAX_PROFILE_SIZE = 65536
|
||||
|
||||
|
||||
class ProfileWorkerStore(SQLBaseStore):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -201,6 +210,89 @@ class ProfileWorkerStore(SQLBaseStore):
|
||||
desc="get_profile_avatar_url",
|
||||
)
|
||||
|
||||
async def get_profile_field(self, user_id: UserID, field_name: str) -> JsonValue:
|
||||
"""
|
||||
Get a custom profile field for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
field_name: The custom profile field name.
|
||||
|
||||
Returns:
|
||||
The string value if the field exists, otherwise raises 404.
|
||||
"""
|
||||
|
||||
def get_profile_field(txn: LoggingTransaction) -> JsonValue:
|
||||
# This will error if field_name has double quotes in it, but that's not
|
||||
# possible due to the grammar.
|
||||
field_path = f'$."{field_name}"'
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
sql = """
|
||||
SELECT JSONB_PATH_EXISTS(fields, ?), JSONB_EXTRACT_PATH(fields, ?)
|
||||
FROM profiles
|
||||
WHERE user_id = ?
|
||||
"""
|
||||
txn.execute(
|
||||
sql,
|
||||
(field_path, field_name, user_id.localpart),
|
||||
)
|
||||
|
||||
# Test exists first since value being None is used for both
|
||||
# missing and a null JSON value.
|
||||
exists, value = cast(Tuple[bool, JsonValue], txn.fetchone())
|
||||
if not exists:
|
||||
raise StoreError(404, "No row found")
|
||||
return value
|
||||
|
||||
else:
|
||||
sql = """
|
||||
SELECT JSON_TYPE(fields, ?), JSON_EXTRACT(fields, ?)
|
||||
FROM profiles
|
||||
WHERE user_id = ?
|
||||
"""
|
||||
txn.execute(
|
||||
sql,
|
||||
(field_path, field_path, user_id.localpart),
|
||||
)
|
||||
|
||||
# If value_type is None, then the value did not exist.
|
||||
value_type, value = cast(
|
||||
Tuple[Optional[str], JsonValue], txn.fetchone()
|
||||
)
|
||||
if not value_type:
|
||||
raise StoreError(404, "No row found")
|
||||
# If value_type is object or array, then need to deserialize the JSON.
|
||||
# Scalar values are properly returned directly.
|
||||
if value_type in ("object", "array"):
|
||||
assert isinstance(value, str)
|
||||
return json.loads(value)
|
||||
return value
|
||||
|
||||
return await self.db_pool.runInteraction("get_profile_field", get_profile_field)
|
||||
|
||||
async def get_profile_fields(self, user_id: UserID) -> Dict[str, str]:
|
||||
"""
|
||||
Get all custom profile fields for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
|
||||
Returns:
|
||||
A dictionary of custom profile fields.
|
||||
"""
|
||||
result = await self.db_pool.simple_select_one_onecol(
|
||||
table="profiles",
|
||||
keyvalues={"full_user_id": user_id.to_string()},
|
||||
retcol="fields",
|
||||
desc="get_profile_fields",
|
||||
)
|
||||
# The SQLite driver doesn't automatically convert JSON to
|
||||
# Python objects
|
||||
if isinstance(self.database_engine, Sqlite3Engine) and result:
|
||||
result = json.loads(result)
|
||||
return result or {}
|
||||
|
||||
async def create_profile(self, user_id: UserID) -> None:
|
||||
"""
|
||||
Create a blank profile for a user.
|
||||
@@ -215,6 +307,71 @@ class ProfileWorkerStore(SQLBaseStore):
|
||||
desc="create_profile",
|
||||
)
|
||||
|
||||
def _check_profile_size(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
user_id: UserID,
|
||||
new_field_name: str,
|
||||
new_value: JsonValue,
|
||||
) -> None:
|
||||
# For each entry there are 4 quotes (2 each for key and value), 1 colon,
|
||||
# and 1 comma.
|
||||
PER_VALUE_EXTRA = 6
|
||||
|
||||
# Add the size of the current custom profile fields, ignoring the entry
|
||||
# which will be overwritten.
|
||||
if isinstance(txn.database_engine, PostgresEngine):
|
||||
size_sql = """
|
||||
SELECT
|
||||
OCTET_LENGTH((fields - ?)::text), OCTET_LENGTH(displayname), OCTET_LENGTH(avatar_url)
|
||||
FROM profiles
|
||||
WHERE
|
||||
user_id = ?
|
||||
"""
|
||||
txn.execute(
|
||||
size_sql,
|
||||
(new_field_name, user_id.localpart),
|
||||
)
|
||||
else:
|
||||
size_sql = """
|
||||
SELECT
|
||||
LENGTH(json_remove(fields, ?)), LENGTH(displayname), LENGTH(avatar_url)
|
||||
FROM profiles
|
||||
WHERE
|
||||
user_id = ?
|
||||
"""
|
||||
txn.execute(
|
||||
size_sql,
|
||||
# This will error if field_name has double quotes in it, but that's not
|
||||
# possible due to the grammar.
|
||||
(f'$."{new_field_name}"', user_id.localpart),
|
||||
)
|
||||
row = cast(Tuple[Optional[int], Optional[int], Optional[int]], txn.fetchone())
|
||||
|
||||
# The values return null if the column is null.
|
||||
total_bytes = (
|
||||
# Discount the opening and closing braces to avoid double counting,
|
||||
# but add one for a comma.
|
||||
# -2 + 1 = -1
|
||||
(row[0] - 1 if row[0] else 0)
|
||||
+ (
|
||||
row[1] + len("displayname") + PER_VALUE_EXTRA
|
||||
if new_field_name != ProfileFields.DISPLAYNAME and row[1]
|
||||
else 0
|
||||
)
|
||||
+ (
|
||||
row[2] + len("avatar_url") + PER_VALUE_EXTRA
|
||||
if new_field_name != ProfileFields.AVATAR_URL and row[2]
|
||||
else 0
|
||||
)
|
||||
)
|
||||
|
||||
# Add the length of the field being added + the braces.
|
||||
total_bytes += len(encode_canonical_json({new_field_name: new_value}))
|
||||
|
||||
if total_bytes > MAX_PROFILE_SIZE:
|
||||
raise StoreError(400, "Profile too large", Codes.PROFILE_TOO_LARGE)
|
||||
|
||||
async def set_profile_displayname(
|
||||
self, user_id: UserID, new_displayname: Optional[str]
|
||||
) -> None:
|
||||
@@ -227,14 +384,25 @@ class ProfileWorkerStore(SQLBaseStore):
|
||||
name is removed.
|
||||
"""
|
||||
user_localpart = user_id.localpart
|
||||
await self.db_pool.simple_upsert(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
values={
|
||||
"displayname": new_displayname,
|
||||
"full_user_id": user_id.to_string(),
|
||||
},
|
||||
desc="set_profile_displayname",
|
||||
|
||||
def set_profile_displayname(txn: LoggingTransaction) -> None:
|
||||
if new_displayname is not None:
|
||||
self._check_profile_size(
|
||||
txn, user_id, ProfileFields.DISPLAYNAME, new_displayname
|
||||
)
|
||||
|
||||
self.db_pool.simple_upsert_txn(
|
||||
txn,
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
values={
|
||||
"displayname": new_displayname,
|
||||
"full_user_id": user_id.to_string(),
|
||||
},
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
"set_profile_displayname", set_profile_displayname
|
||||
)
|
||||
|
||||
async def set_profile_avatar_url(
|
||||
@@ -249,13 +417,125 @@ class ProfileWorkerStore(SQLBaseStore):
|
||||
removed.
|
||||
"""
|
||||
user_localpart = user_id.localpart
|
||||
await self.db_pool.simple_upsert(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
values={"avatar_url": new_avatar_url, "full_user_id": user_id.to_string()},
|
||||
desc="set_profile_avatar_url",
|
||||
|
||||
def set_profile_avatar_url(txn: LoggingTransaction) -> None:
|
||||
if new_avatar_url is not None:
|
||||
self._check_profile_size(
|
||||
txn, user_id, ProfileFields.AVATAR_URL, new_avatar_url
|
||||
)
|
||||
|
||||
self.db_pool.simple_upsert_txn(
|
||||
txn,
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
values={
|
||||
"avatar_url": new_avatar_url,
|
||||
"full_user_id": user_id.to_string(),
|
||||
},
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
"set_profile_avatar_url", set_profile_avatar_url
|
||||
)
|
||||
|
||||
async def set_profile_field(
|
||||
self, user_id: UserID, field_name: str, new_value: JsonValue
|
||||
) -> None:
|
||||
"""
|
||||
Set a custom profile field for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
field_name: The name of the custom profile field.
|
||||
new_value: The value of the custom profile field.
|
||||
"""
|
||||
|
||||
# Encode to canonical JSON.
|
||||
canonical_value = encode_canonical_json(new_value)
|
||||
|
||||
def set_profile_field(txn: LoggingTransaction) -> None:
|
||||
self._check_profile_size(txn, user_id, field_name, new_value)
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
from psycopg2.extras import Json
|
||||
|
||||
# Note that the || jsonb operator is not recursive, any duplicate
|
||||
# keys will be taken from the second value.
|
||||
sql = """
|
||||
INSERT INTO profiles (user_id, full_user_id, fields) VALUES (?, ?, JSON_BUILD_OBJECT(?, ?::jsonb))
|
||||
ON CONFLICT (user_id)
|
||||
DO UPDATE SET full_user_id = EXCLUDED.full_user_id, fields = COALESCE(profiles.fields, '{}'::jsonb) || EXCLUDED.fields
|
||||
"""
|
||||
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
user_id.localpart,
|
||||
user_id.to_string(),
|
||||
field_name,
|
||||
# Pass as a JSON object since we have passing bytes disabled
|
||||
# at the database driver.
|
||||
Json(json.loads(canonical_value)),
|
||||
),
|
||||
)
|
||||
else:
|
||||
# You may be tempted to use json_patch instead of providing the parameters
|
||||
# twice, but that recursively merges objects instead of replacing.
|
||||
sql = """
|
||||
INSERT INTO profiles (user_id, full_user_id, fields) VALUES (?, ?, JSON_OBJECT(?, JSON(?)))
|
||||
ON CONFLICT (user_id)
|
||||
DO UPDATE SET full_user_id = EXCLUDED.full_user_id, fields = JSON_SET(COALESCE(profiles.fields, '{}'), ?, JSON(?))
|
||||
"""
|
||||
# This will error if field_name has double quotes in it, but that's not
|
||||
# possible due to the grammar.
|
||||
json_field_name = f'$."{field_name}"'
|
||||
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
user_id.localpart,
|
||||
user_id.to_string(),
|
||||
json_field_name,
|
||||
canonical_value,
|
||||
json_field_name,
|
||||
canonical_value,
|
||||
),
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction("set_profile_field", set_profile_field)
|
||||
|
||||
async def delete_profile_field(self, user_id: UserID, field_name: str) -> None:
|
||||
"""
|
||||
Remove a custom profile field for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
field_name: The name of the custom profile field.
|
||||
"""
|
||||
|
||||
def delete_profile_field(txn: LoggingTransaction) -> None:
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
sql = """
|
||||
UPDATE profiles SET fields = fields - ?
|
||||
WHERE user_id = ?
|
||||
"""
|
||||
txn.execute(
|
||||
sql,
|
||||
(field_name, user_id.localpart),
|
||||
)
|
||||
else:
|
||||
sql = """
|
||||
UPDATE profiles SET fields = json_remove(fields, ?)
|
||||
WHERE user_id = ?
|
||||
"""
|
||||
txn.execute(
|
||||
sql,
|
||||
# This will error if field_name has double quotes in it.
|
||||
(f'$."{field_name}"', user_id.localpart),
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction("delete_profile_field", delete_profile_field)
|
||||
|
||||
|
||||
class ProfileStore(ProfileWorkerStore):
|
||||
pass
|
||||
|
||||
@@ -1181,6 +1181,50 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
||||
|
||||
return total_media_quarantined
|
||||
|
||||
async def block_room(self, room_id: str, user_id: str) -> None:
|
||||
"""Marks the room as blocked.
|
||||
|
||||
Can be called multiple times (though we'll only track the last user to
|
||||
block this room).
|
||||
|
||||
Can be called on a room unknown to this homeserver.
|
||||
|
||||
Args:
|
||||
room_id: Room to block
|
||||
user_id: Who blocked it
|
||||
"""
|
||||
await self.db_pool.simple_upsert(
|
||||
table="blocked_rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
values={},
|
||||
insertion_values={"user_id": user_id},
|
||||
desc="block_room",
|
||||
)
|
||||
await self.db_pool.runInteraction(
|
||||
"block_room_invalidation",
|
||||
self._invalidate_cache_and_stream,
|
||||
self.is_room_blocked,
|
||||
(room_id,),
|
||||
)
|
||||
|
||||
async def unblock_room(self, room_id: str) -> None:
|
||||
"""Remove the room from blocking list.
|
||||
|
||||
Args:
|
||||
room_id: Room to unblock
|
||||
"""
|
||||
await self.db_pool.simple_delete(
|
||||
table="blocked_rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
desc="unblock_room",
|
||||
)
|
||||
await self.db_pool.runInteraction(
|
||||
"block_room_invalidation",
|
||||
self._invalidate_cache_and_stream,
|
||||
self.is_room_blocked,
|
||||
(room_id,),
|
||||
)
|
||||
|
||||
async def get_rooms_for_retention_period_in_range(
|
||||
self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
|
||||
) -> Dict[str, RetentionPolicy]:
|
||||
@@ -2500,50 +2544,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
|
||||
)
|
||||
return next_id
|
||||
|
||||
async def block_room(self, room_id: str, user_id: str) -> None:
|
||||
"""Marks the room as blocked.
|
||||
|
||||
Can be called multiple times (though we'll only track the last user to
|
||||
block this room).
|
||||
|
||||
Can be called on a room unknown to this homeserver.
|
||||
|
||||
Args:
|
||||
room_id: Room to block
|
||||
user_id: Who blocked it
|
||||
"""
|
||||
await self.db_pool.simple_upsert(
|
||||
table="blocked_rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
values={},
|
||||
insertion_values={"user_id": user_id},
|
||||
desc="block_room",
|
||||
)
|
||||
await self.db_pool.runInteraction(
|
||||
"block_room_invalidation",
|
||||
self._invalidate_cache_and_stream,
|
||||
self.is_room_blocked,
|
||||
(room_id,),
|
||||
)
|
||||
|
||||
async def unblock_room(self, room_id: str) -> None:
|
||||
"""Remove the room from blocking list.
|
||||
|
||||
Args:
|
||||
room_id: Room to unblock
|
||||
"""
|
||||
await self.db_pool.simple_delete(
|
||||
table="blocked_rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
desc="unblock_room",
|
||||
)
|
||||
await self.db_pool.runInteraction(
|
||||
"block_room_invalidation",
|
||||
self._invalidate_cache_and_stream,
|
||||
self.is_room_blocked,
|
||||
(room_id,),
|
||||
)
|
||||
|
||||
async def clear_partial_state_room(self, room_id: str) -> Optional[int]:
|
||||
"""Clears the partial state flag for a room.
|
||||
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
--
|
||||
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
--
|
||||
-- Copyright (C) 2024 Patrick Cloke
|
||||
--
|
||||
-- This program is free software: you can redistribute it and/or modify
|
||||
-- it under the terms of the GNU Affero General Public License as
|
||||
-- published by the Free Software Foundation, either version 3 of the
|
||||
-- License, or (at your option) any later version.
|
||||
--
|
||||
-- See the GNU Affero General Public License for more details:
|
||||
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
-- Custom profile fields.
|
||||
ALTER TABLE profiles ADD COLUMN fields JSONB;
|
||||
@@ -43,6 +43,14 @@ CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$")
|
||||
#
|
||||
MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
|
||||
|
||||
# https://spec.matrix.org/v1.13/appendices/#common-namespaced-identifier-grammar
|
||||
#
|
||||
# At least one character, less than or equal to 255 characters. Must start with
|
||||
# a-z, the rest is a-z, 0-9, -, _, or ..
|
||||
#
|
||||
# This doesn't check anything about validity of namespaces.
|
||||
NAMESPACED_GRAMMAR = re.compile(r"^[a-z][a-z0-9_.-]{0,254}$")
|
||||
|
||||
|
||||
def random_string(length: int) -> str:
|
||||
"""Generate a cryptographically secure string of random letters.
|
||||
@@ -68,6 +76,10 @@ def is_ascii(s: bytes) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def is_namedspaced_grammar(s: str) -> bool:
|
||||
return bool(NAMESPACED_GRAMMAR.match(s))
|
||||
|
||||
|
||||
def assert_valid_client_secret(client_secret: str) -> None:
|
||||
"""Validate that a given string matches the client_secret defined by the spec"""
|
||||
if (
|
||||
|
||||
@@ -39,7 +39,7 @@ from synapse.module_api import ModuleApi
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import StreamToken, create_requester
|
||||
from synapse.types import StreamToken, UserID, UserInfo, create_requester
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.handlers.test_sync import generate_sync_config
|
||||
@@ -349,6 +349,169 @@ class AutoAcceptInvitesTestCase(FederatingHomeserverTestCase):
|
||||
join_updates, _ = sync_join(self, invited_user_id)
|
||||
self.assertEqual(len(join_updates), 0)
|
||||
|
||||
@override_config(
|
||||
{
|
||||
"auto_accept_invites": {
|
||||
"enabled": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
async def test_ignore_invite_for_missing_user(self) -> None:
|
||||
"""Tests that receiving an invite for a missing user is ignored."""
|
||||
inviting_user_id = self.register_user("inviter", "pass")
|
||||
inviting_user_tok = self.login("inviter", "pass")
|
||||
|
||||
# A local user who receives an invite
|
||||
invited_user_id = "@fake:" + self.hs.config.server.server_name
|
||||
|
||||
# Create a room and send an invite to the other user
|
||||
room_id = self.helper.create_room_as(
|
||||
inviting_user_id,
|
||||
tok=inviting_user_tok,
|
||||
)
|
||||
|
||||
self.helper.invite(
|
||||
room_id,
|
||||
inviting_user_id,
|
||||
invited_user_id,
|
||||
tok=inviting_user_tok,
|
||||
)
|
||||
|
||||
join_updates, _ = sync_join(self, inviting_user_id)
|
||||
# Assert that the last event in the room was not a member event for the target user.
|
||||
self.assertEqual(
|
||||
join_updates[0].timeline.events[-1].content["membership"], "invite"
|
||||
)
|
||||
|
||||
@override_config(
|
||||
{
|
||||
"auto_accept_invites": {
|
||||
"enabled": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
async def test_ignore_invite_for_deactivated_user(self) -> None:
|
||||
"""Tests that receiving an invite for a deactivated user is ignored."""
|
||||
inviting_user_id = self.register_user("inviter", "pass", admin=True)
|
||||
inviting_user_tok = self.login("inviter", "pass")
|
||||
|
||||
# A local user who receives an invite
|
||||
invited_user_id = self.register_user("invitee", "pass")
|
||||
|
||||
# Create a room and send an invite to the other user
|
||||
room_id = self.helper.create_room_as(
|
||||
inviting_user_id,
|
||||
tok=inviting_user_tok,
|
||||
)
|
||||
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
"/_synapse/admin/v2/users/%s" % invited_user_id,
|
||||
{"deactivated": True},
|
||||
access_token=inviting_user_tok,
|
||||
)
|
||||
|
||||
assert channel.code == 200
|
||||
|
||||
self.helper.invite(
|
||||
room_id,
|
||||
inviting_user_id,
|
||||
invited_user_id,
|
||||
tok=inviting_user_tok,
|
||||
)
|
||||
|
||||
join_updates, b = sync_join(self, inviting_user_id)
|
||||
# Assert that the last event in the room was not a member event for the target user.
|
||||
self.assertEqual(
|
||||
join_updates[0].timeline.events[-1].content["membership"], "invite"
|
||||
)
|
||||
|
||||
@override_config(
|
||||
{
|
||||
"auto_accept_invites": {
|
||||
"enabled": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
async def test_ignore_invite_for_suspended_user(self) -> None:
|
||||
"""Tests that receiving an invite for a suspended user is ignored."""
|
||||
inviting_user_id = self.register_user("inviter", "pass", admin=True)
|
||||
inviting_user_tok = self.login("inviter", "pass")
|
||||
|
||||
# A local user who receives an invite
|
||||
invited_user_id = self.register_user("invitee", "pass")
|
||||
|
||||
# Create a room and send an invite to the other user
|
||||
room_id = self.helper.create_room_as(
|
||||
inviting_user_id,
|
||||
tok=inviting_user_tok,
|
||||
)
|
||||
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_synapse/admin/v1/suspend/{invited_user_id}",
|
||||
{"suspend": True},
|
||||
access_token=inviting_user_tok,
|
||||
)
|
||||
|
||||
assert channel.code == 200
|
||||
|
||||
self.helper.invite(
|
||||
room_id,
|
||||
inviting_user_id,
|
||||
invited_user_id,
|
||||
tok=inviting_user_tok,
|
||||
)
|
||||
|
||||
join_updates, b = sync_join(self, inviting_user_id)
|
||||
# Assert that the last event in the room was not a member event for the target user.
|
||||
self.assertEqual(
|
||||
join_updates[0].timeline.events[-1].content["membership"], "invite"
|
||||
)
|
||||
|
||||
@override_config(
|
||||
{
|
||||
"auto_accept_invites": {
|
||||
"enabled": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
async def test_ignore_invite_for_locked_user(self) -> None:
|
||||
"""Tests that receiving an invite for a suspended user is ignored."""
|
||||
inviting_user_id = self.register_user("inviter", "pass", admin=True)
|
||||
inviting_user_tok = self.login("inviter", "pass")
|
||||
|
||||
# A local user who receives an invite
|
||||
invited_user_id = self.register_user("invitee", "pass")
|
||||
|
||||
# Create a room and send an invite to the other user
|
||||
room_id = self.helper.create_room_as(
|
||||
inviting_user_id,
|
||||
tok=inviting_user_tok,
|
||||
)
|
||||
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_synapse/admin/v2/users/{invited_user_id}",
|
||||
{"locked": True},
|
||||
access_token=inviting_user_tok,
|
||||
)
|
||||
|
||||
assert channel.code == 200
|
||||
|
||||
self.helper.invite(
|
||||
room_id,
|
||||
inviting_user_id,
|
||||
invited_user_id,
|
||||
tok=inviting_user_tok,
|
||||
)
|
||||
|
||||
join_updates, b = sync_join(self, inviting_user_id)
|
||||
# Assert that the last event in the room was not a member event for the target user.
|
||||
self.assertEqual(
|
||||
join_updates[0].timeline.events[-1].content["membership"], "invite"
|
||||
)
|
||||
|
||||
|
||||
_request_key = 0
|
||||
|
||||
@@ -647,6 +810,22 @@ def create_module(
|
||||
module_api.is_mine.side_effect = lambda a: a.split(":")[1] == "test"
|
||||
module_api.worker_name = worker_name
|
||||
module_api.sleep.return_value = make_multiple_awaitable(None)
|
||||
module_api.get_userinfo_by_id.return_value = UserInfo(
|
||||
user_id=UserID.from_string("@user:test"),
|
||||
is_admin=False,
|
||||
is_guest=False,
|
||||
consent_server_notice_sent=None,
|
||||
consent_ts=None,
|
||||
consent_version=None,
|
||||
appservice_id=None,
|
||||
creation_ts=0,
|
||||
user_type=None,
|
||||
is_deactivated=False,
|
||||
locked=False,
|
||||
is_shadow_banned=False,
|
||||
approved=True,
|
||||
suspended=False,
|
||||
)
|
||||
|
||||
if config_override is None:
|
||||
config_override = {}
|
||||
|
||||
161
tests/federation/test_federation_devices.py
Normal file
161
tests/federation/test_federation_devices.py
Normal file
@@ -0,0 +1,161 @@
|
||||
#
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright (C) 2024 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
#
|
||||
# Originally licensed under the Apache License, Version 2.0:
|
||||
# <http://www.apache.org/licenses/LICENSE-2.0>.
|
||||
#
|
||||
# [This file includes modifications made by New Vector Limited]
|
||||
#
|
||||
#
|
||||
|
||||
import logging
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.handlers.device import DeviceListUpdater
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import Clock
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
||||
from tests import unittest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeviceListResyncTestCase(unittest.HomeserverTestCase):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.store = self.hs.get_datastores().main
|
||||
|
||||
def test_retry_device_list_resync(self) -> None:
|
||||
"""Tests that device lists are marked as stale if they couldn't be synced, and
|
||||
that stale device lists are retried periodically.
|
||||
"""
|
||||
remote_user_id = "@john:test_remote"
|
||||
remote_origin = "test_remote"
|
||||
|
||||
# Track the number of attempts to resync the user's device list.
|
||||
self.resync_attempts = 0
|
||||
|
||||
# When this function is called, increment the number of resync attempts (only if
|
||||
# we're querying devices for the right user ID), then raise a
|
||||
# NotRetryingDestination error to fail the resync gracefully.
|
||||
def query_user_devices(
|
||||
destination: str, user_id: str, timeout: int = 30000
|
||||
) -> JsonDict:
|
||||
if user_id == remote_user_id:
|
||||
self.resync_attempts += 1
|
||||
|
||||
raise NotRetryingDestination(0, 0, destination)
|
||||
|
||||
# Register the mock on the federation client.
|
||||
federation_client = self.hs.get_federation_client()
|
||||
federation_client.query_user_devices = Mock(side_effect=query_user_devices) # type: ignore[method-assign]
|
||||
|
||||
# Register a mock on the store so that the incoming update doesn't fail because
|
||||
# we don't share a room with the user.
|
||||
self.store.get_rooms_for_user = AsyncMock(return_value=["!someroom:test"])
|
||||
|
||||
# Manually inject a fake device list update. We need this update to include at
|
||||
# least one prev_id so that the user's device list will need to be retried.
|
||||
device_list_updater = self.hs.get_device_handler().device_list_updater
|
||||
assert isinstance(device_list_updater, DeviceListUpdater)
|
||||
self.get_success(
|
||||
device_list_updater.incoming_device_list_update(
|
||||
origin=remote_origin,
|
||||
edu_content={
|
||||
"deleted": False,
|
||||
"device_display_name": "Mobile",
|
||||
"device_id": "QBUAZIFURK",
|
||||
"prev_id": [5],
|
||||
"stream_id": 6,
|
||||
"user_id": remote_user_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Check that there was one resync attempt.
|
||||
self.assertEqual(self.resync_attempts, 1)
|
||||
|
||||
# Check that the resync attempt failed and caused the user's device list to be
|
||||
# marked as stale.
|
||||
need_resync = self.get_success(
|
||||
self.store.get_user_ids_requiring_device_list_resync()
|
||||
)
|
||||
self.assertIn(remote_user_id, need_resync)
|
||||
|
||||
# Check that waiting for 30 seconds caused Synapse to retry resyncing the device
|
||||
# list.
|
||||
self.reactor.advance(30)
|
||||
self.assertEqual(self.resync_attempts, 2)
|
||||
|
||||
def test_cross_signing_keys_retry(self) -> None:
|
||||
"""Tests that resyncing a device list correctly processes cross-signing keys from
|
||||
the remote server.
|
||||
"""
|
||||
remote_user_id = "@john:test_remote"
|
||||
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
|
||||
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
|
||||
|
||||
# Register mock device list retrieval on the federation client.
|
||||
federation_client = self.hs.get_federation_client()
|
||||
federation_client.query_user_devices = AsyncMock( # type: ignore[method-assign]
|
||||
return_value={
|
||||
"user_id": remote_user_id,
|
||||
"stream_id": 1,
|
||||
"devices": [],
|
||||
"master_key": {
|
||||
"user_id": remote_user_id,
|
||||
"usage": ["master"],
|
||||
"keys": {"ed25519:" + remote_master_key: remote_master_key},
|
||||
},
|
||||
"self_signing_key": {
|
||||
"user_id": remote_user_id,
|
||||
"usage": ["self_signing"],
|
||||
"keys": {
|
||||
"ed25519:" + remote_self_signing_key: remote_self_signing_key
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Resync the device list.
|
||||
device_handler = self.hs.get_device_handler()
|
||||
self.get_success(
|
||||
device_handler.device_list_updater.multi_user_device_resync(
|
||||
[remote_user_id]
|
||||
),
|
||||
)
|
||||
|
||||
# Retrieve the cross-signing keys for this user.
|
||||
keys = self.get_success(
|
||||
self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]),
|
||||
)
|
||||
self.assertIn(remote_user_id, keys)
|
||||
key = keys[remote_user_id]
|
||||
assert key is not None
|
||||
|
||||
# Check that the master key is the one returned by the mock.
|
||||
master_key = key["master"]
|
||||
self.assertEqual(len(master_key["keys"]), 1)
|
||||
self.assertTrue("ed25519:" + remote_master_key in master_key["keys"].keys())
|
||||
self.assertTrue(remote_master_key in master_key["keys"].values())
|
||||
|
||||
# Check that the self-signing key is the one returned by the mock.
|
||||
self_signing_key = key["self_signing"]
|
||||
self.assertEqual(len(self_signing_key["keys"]), 1)
|
||||
self.assertTrue(
|
||||
"ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(),
|
||||
)
|
||||
self.assertTrue(remote_self_signing_key in self_signing_key["keys"].values())
|
||||
671
tests/federation/test_federation_out_of_band_membership.py
Normal file
671
tests/federation/test_federation_out_of_band_membership.py
Normal file
@@ -0,0 +1,671 @@
|
||||
#
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
# Copyright (C) 2023 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
#
|
||||
# Originally licensed under the Apache License, Version 2.0:
|
||||
# <http://www.apache.org/licenses/LICENSE-2.0>.
|
||||
#
|
||||
# [This file includes modifications made by New Vector Limited]
|
||||
#
|
||||
#
|
||||
|
||||
import logging
|
||||
import time
|
||||
import urllib.parse
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Callable, Optional, Set, Tuple, TypeVar, Union
|
||||
from unittest.mock import Mock
|
||||
|
||||
import attr
|
||||
from parameterized import parameterized
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.constants import EventContentFields, EventTypes, Membership
|
||||
from synapse.api.room_versions import RoomVersion, RoomVersions
|
||||
from synapse.events import EventBase, make_event_from_dict
|
||||
from synapse.events.utils import strip_event
|
||||
from synapse.federation.federation_base import (
|
||||
event_from_pdu_json,
|
||||
)
|
||||
from synapse.federation.transport.client import SendJoinResponse
|
||||
from synapse.http.matrixfederationclient import (
|
||||
ByteParser,
|
||||
)
|
||||
from synapse.http.types import QueryParams
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login, room, sync
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, MutableStateMap, StateMap
|
||||
from synapse.types.handlers.sliding_sync import (
|
||||
StateValues,
|
||||
)
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.utils import test_timeout
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def required_state_json_to_state_map(required_state: Any) -> StateMap[EventBase]:
|
||||
state_map: MutableStateMap[EventBase] = {}
|
||||
|
||||
# Scrutinize JSON values to ensure it's in the expected format
|
||||
if isinstance(required_state, list):
|
||||
for state_event_dict in required_state:
|
||||
# Yell because we're in a test and this is unexpected
|
||||
assert isinstance(
|
||||
state_event_dict, dict
|
||||
), "`required_state` should be a list of event dicts"
|
||||
|
||||
event_type = state_event_dict["type"]
|
||||
event_state_key = state_event_dict["state_key"]
|
||||
|
||||
# Yell because we're in a test and this is unexpected
|
||||
assert isinstance(
|
||||
event_type, str
|
||||
), "Each event in `required_state` should have a string `type`"
|
||||
assert isinstance(
|
||||
event_state_key, str
|
||||
), "Each event in `required_state` should have a string `state_key`"
|
||||
|
||||
state_map[(event_type, event_state_key)] = make_event_from_dict(
|
||||
state_event_dict
|
||||
)
|
||||
else:
|
||||
# Yell because we're in a test and this is unexpected
|
||||
raise AssertionError("`required_state` should be a list of event dicts")
|
||||
|
||||
return state_map
|
||||
|
||||
|
||||
@attr.s(slots=True, auto_attribs=True)
|
||||
class RemoteRoomJoinResult:
|
||||
remote_room_id: str
|
||||
room_version: RoomVersion
|
||||
remote_room_creator_user_id: str
|
||||
local_user1_id: str
|
||||
local_user1_tok: str
|
||||
state_map: StateMap[EventBase]
|
||||
|
||||
|
||||
class OutOfBandMembershipTests(unittest.FederatingHomeserverTestCase):
|
||||
"""
|
||||
Tests to make sure that interactions with out-of-band membership (outliers) works as
|
||||
expected.
|
||||
|
||||
- invites received over federation, before we join the room
|
||||
- *rejections* for said invites
|
||||
|
||||
See the "Out-of-band membership events" section in
|
||||
`docs/development/room-dag-concepts.md` for more information.
|
||||
"""
|
||||
|
||||
servlets = [
|
||||
admin.register_servlets,
|
||||
room.register_servlets,
|
||||
login.register_servlets,
|
||||
sync.register_servlets,
|
||||
]
|
||||
|
||||
sync_endpoint = "/_matrix/client/unstable/org.matrix.simplified_msc3575/sync"
|
||||
|
||||
def default_config(self) -> JsonDict:
|
||||
conf = super().default_config()
|
||||
# Federation sending is disabled by default in the test environment
|
||||
# so we need to enable it like this.
|
||||
conf["federation_sender_instances"] = ["master"]
|
||||
|
||||
return conf
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
self.federation_http_client = Mock(
|
||||
# The problem with using `spec=MatrixFederationHttpClient` here is that it
|
||||
# requires everything to be mocked which is a lot of work that I don't want
|
||||
# to do when the code only uses a few methods (`get_json` and `put_json`).
|
||||
)
|
||||
return self.setup_test_homeserver(
|
||||
federation_http_client=self.federation_http_client
|
||||
)
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
super().prepare(reactor, clock, hs)
|
||||
|
||||
self.store = self.hs.get_datastores().main
|
||||
self.storage_controllers = hs.get_storage_controllers()
|
||||
|
||||
def do_sync(
|
||||
self, sync_body: JsonDict, *, since: Optional[str] = None, tok: str
|
||||
) -> Tuple[JsonDict, str]:
|
||||
"""Do a sliding sync request with given body.
|
||||
|
||||
Asserts the request was successful.
|
||||
|
||||
Attributes:
|
||||
sync_body: The full request body to use
|
||||
since: Optional since token
|
||||
tok: Access token to use
|
||||
|
||||
Returns:
|
||||
A tuple of the response body and the `pos` field.
|
||||
"""
|
||||
|
||||
sync_path = self.sync_endpoint
|
||||
if since:
|
||||
sync_path += f"?pos={since}"
|
||||
|
||||
channel = self.make_request(
|
||||
method="POST",
|
||||
path=sync_path,
|
||||
content=sync_body,
|
||||
access_token=tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
|
||||
return channel.json_body, channel.json_body["pos"]
|
||||
|
||||
def _invite_local_user_to_remote_room_and_join(self) -> RemoteRoomJoinResult:
|
||||
"""
|
||||
Helper to reproduce this scenario:
|
||||
|
||||
1. The remote user invites our local user to a room on their remote server (which
|
||||
creates an out-of-band invite membership for user1 on our local server).
|
||||
2. The local user notices the invite from `/sync`.
|
||||
3. The local user joins the room.
|
||||
4. The local user can see that they are now joined to the room from `/sync`.
|
||||
"""
|
||||
|
||||
# Create a local user
|
||||
local_user1_id = self.register_user("user1", "pass")
|
||||
local_user1_tok = self.login(local_user1_id, "pass")
|
||||
|
||||
# Create a remote room
|
||||
room_creator_user_id = f"@remote-user:{self.OTHER_SERVER_NAME}"
|
||||
remote_room_id = f"!remote-room:{self.OTHER_SERVER_NAME}"
|
||||
room_version = RoomVersions.V10
|
||||
|
||||
room_create_event = make_event_from_dict(
|
||||
self.add_hashes_and_signatures_from_other_server(
|
||||
{
|
||||
"room_id": remote_room_id,
|
||||
"sender": room_creator_user_id,
|
||||
"depth": 1,
|
||||
"origin_server_ts": 1,
|
||||
"type": EventTypes.Create,
|
||||
"state_key": "",
|
||||
"content": {
|
||||
# The `ROOM_CREATOR` field could be removed if we used a room
|
||||
# version > 10 (in favor of relying on `sender`)
|
||||
EventContentFields.ROOM_CREATOR: room_creator_user_id,
|
||||
EventContentFields.ROOM_VERSION: room_version.identifier,
|
||||
},
|
||||
"auth_events": [],
|
||||
"prev_events": [],
|
||||
}
|
||||
),
|
||||
room_version=room_version,
|
||||
)
|
||||
|
||||
creator_membership_event = make_event_from_dict(
|
||||
self.add_hashes_and_signatures_from_other_server(
|
||||
{
|
||||
"room_id": remote_room_id,
|
||||
"sender": room_creator_user_id,
|
||||
"depth": 2,
|
||||
"origin_server_ts": 2,
|
||||
"type": EventTypes.Member,
|
||||
"state_key": room_creator_user_id,
|
||||
"content": {"membership": Membership.JOIN},
|
||||
"auth_events": [room_create_event.event_id],
|
||||
"prev_events": [room_create_event.event_id],
|
||||
}
|
||||
),
|
||||
room_version=room_version,
|
||||
)
|
||||
|
||||
# From the remote homeserver, invite user1 on the local homserver
|
||||
user1_invite_membership_event = make_event_from_dict(
|
||||
self.add_hashes_and_signatures_from_other_server(
|
||||
{
|
||||
"room_id": remote_room_id,
|
||||
"sender": room_creator_user_id,
|
||||
"depth": 3,
|
||||
"origin_server_ts": 3,
|
||||
"type": EventTypes.Member,
|
||||
"state_key": local_user1_id,
|
||||
"content": {"membership": Membership.INVITE},
|
||||
"auth_events": [
|
||||
room_create_event.event_id,
|
||||
creator_membership_event.event_id,
|
||||
],
|
||||
"prev_events": [creator_membership_event.event_id],
|
||||
}
|
||||
),
|
||||
room_version=room_version,
|
||||
)
|
||||
channel = self.make_signed_federation_request(
|
||||
"PUT",
|
||||
f"/_matrix/federation/v2/invite/{remote_room_id}/{user1_invite_membership_event.event_id}",
|
||||
content={
|
||||
"event": user1_invite_membership_event.get_dict(),
|
||||
"invite_room_state": [
|
||||
strip_event(room_create_event),
|
||||
],
|
||||
"room_version": room_version.identifier,
|
||||
},
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
|
||||
|
||||
sync_body = {
|
||||
"lists": {
|
||||
"foo-list": {
|
||||
"ranges": [[0, 1]],
|
||||
"required_state": [(EventTypes.Member, StateValues.WILDCARD)],
|
||||
"timeline_limit": 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Sync until the local user1 can see the invite
|
||||
with test_timeout(
|
||||
3,
|
||||
"Unable to find user1's invite event in the room",
|
||||
):
|
||||
while True:
|
||||
response_body, _ = self.do_sync(sync_body, tok=local_user1_tok)
|
||||
if (
|
||||
remote_room_id in response_body["rooms"].keys()
|
||||
# If they have `invite_state` for the room, they are invited
|
||||
and len(
|
||||
response_body["rooms"][remote_room_id].get("invite_state", [])
|
||||
)
|
||||
> 0
|
||||
):
|
||||
break
|
||||
|
||||
# Prevent tight-looping to allow the `test_timeout` to work
|
||||
time.sleep(0.1)
|
||||
|
||||
user1_join_membership_event_template = make_event_from_dict(
|
||||
{
|
||||
"room_id": remote_room_id,
|
||||
"sender": local_user1_id,
|
||||
"depth": 4,
|
||||
"origin_server_ts": 4,
|
||||
"type": EventTypes.Member,
|
||||
"state_key": local_user1_id,
|
||||
"content": {"membership": Membership.JOIN},
|
||||
"auth_events": [
|
||||
room_create_event.event_id,
|
||||
user1_invite_membership_event.event_id,
|
||||
],
|
||||
"prev_events": [user1_invite_membership_event.event_id],
|
||||
},
|
||||
room_version=room_version,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
# Mock the remote homeserver responding to our HTTP requests
|
||||
#
|
||||
# We're going to mock the following endpoints so that user1 can join the remote room:
|
||||
# - GET /_matrix/federation/v1/make_join/{room_id}/{user_id}
|
||||
# - PUT /_matrix/federation/v2/send_join/{room_id}/{user_id}
|
||||
#
|
||||
async def get_json(
|
||||
destination: str,
|
||||
path: str,
|
||||
args: Optional[QueryParams] = None,
|
||||
retry_on_dns_fail: bool = True,
|
||||
timeout: Optional[int] = None,
|
||||
ignore_backoff: bool = False,
|
||||
try_trailing_slash_on_400: bool = False,
|
||||
parser: Optional[ByteParser[T]] = None,
|
||||
) -> Union[JsonDict, T]:
|
||||
if (
|
||||
path
|
||||
== f"/_matrix/federation/v1/make_join/{urllib.parse.quote_plus(remote_room_id)}/{urllib.parse.quote_plus(local_user1_id)}"
|
||||
):
|
||||
return {
|
||||
"event": user1_join_membership_event_template.get_pdu_json(),
|
||||
"room_version": room_version.identifier,
|
||||
}
|
||||
|
||||
raise NotImplementedError(
|
||||
"We have not mocked a response for `get_json(...)` for the following endpoint yet: "
|
||||
+ f"{destination}{path}"
|
||||
)
|
||||
|
||||
self.federation_http_client.get_json.side_effect = get_json
|
||||
|
||||
# PDU's that hs1 sent to hs2
|
||||
collected_pdus_from_hs1_federation_send: Set[str] = set()
|
||||
|
||||
async def put_json(
|
||||
destination: str,
|
||||
path: str,
|
||||
args: Optional[QueryParams] = None,
|
||||
data: Optional[JsonDict] = None,
|
||||
json_data_callback: Optional[Callable[[], JsonDict]] = None,
|
||||
long_retries: bool = False,
|
||||
timeout: Optional[int] = None,
|
||||
ignore_backoff: bool = False,
|
||||
backoff_on_404: bool = False,
|
||||
try_trailing_slash_on_400: bool = False,
|
||||
parser: Optional[ByteParser[T]] = None,
|
||||
backoff_on_all_error_codes: bool = False,
|
||||
) -> Union[JsonDict, T, SendJoinResponse]:
|
||||
if (
|
||||
path.startswith(
|
||||
f"/_matrix/federation/v2/send_join/{urllib.parse.quote_plus(remote_room_id)}/"
|
||||
)
|
||||
and data is not None
|
||||
and data.get("type") == EventTypes.Member
|
||||
and data.get("state_key") == local_user1_id
|
||||
# We're assuming this is a `ByteParser[SendJoinResponse]`
|
||||
and parser is not None
|
||||
):
|
||||
# As the remote server, we need to sign the event before sending it back
|
||||
user1_join_membership_event_signed = make_event_from_dict(
|
||||
self.add_hashes_and_signatures_from_other_server(data),
|
||||
room_version=room_version,
|
||||
)
|
||||
|
||||
# Since they passed in a `parser`, we need to return the type that
|
||||
# they're expecting instead of just a `JsonDict`
|
||||
return SendJoinResponse(
|
||||
auth_events=[
|
||||
room_create_event,
|
||||
user1_invite_membership_event,
|
||||
],
|
||||
state=[
|
||||
room_create_event,
|
||||
creator_membership_event,
|
||||
user1_invite_membership_event,
|
||||
],
|
||||
event_dict=user1_join_membership_event_signed.get_pdu_json(),
|
||||
event=user1_join_membership_event_signed,
|
||||
members_omitted=False,
|
||||
servers_in_room=[
|
||||
self.OTHER_SERVER_NAME,
|
||||
],
|
||||
)
|
||||
|
||||
if path.startswith("/_matrix/federation/v1/send/") and data is not None:
|
||||
for pdu in data.get("pdus", []):
|
||||
event = event_from_pdu_json(pdu, room_version)
|
||||
collected_pdus_from_hs1_federation_send.add(event.event_id)
|
||||
|
||||
# Just acknowledge everything hs1 is trying to send hs2
|
||||
return {
|
||||
event_from_pdu_json(pdu, room_version).event_id: {}
|
||||
for pdu in data.get("pdus", [])
|
||||
}
|
||||
|
||||
raise NotImplementedError(
|
||||
"We have not mocked a response for `put_json(...)` for the following endpoint yet: "
|
||||
+ f"{destination}{path} with the following body data: {data}"
|
||||
)
|
||||
|
||||
self.federation_http_client.put_json.side_effect = put_json
|
||||
|
||||
# User1 joins the room
|
||||
self.helper.join(remote_room_id, local_user1_id, tok=local_user1_tok)
|
||||
|
||||
# Reset the mocks now that user1 has joined the room
|
||||
self.federation_http_client.get_json.side_effect = None
|
||||
self.federation_http_client.put_json.side_effect = None
|
||||
|
||||
# Sync until the local user1 can see that they are now joined to the room
|
||||
with test_timeout(
|
||||
3,
|
||||
"Unable to find user1's join event in the room",
|
||||
):
|
||||
while True:
|
||||
response_body, _ = self.do_sync(sync_body, tok=local_user1_tok)
|
||||
if remote_room_id in response_body["rooms"].keys():
|
||||
required_state_map = required_state_json_to_state_map(
|
||||
response_body["rooms"][remote_room_id]["required_state"]
|
||||
)
|
||||
if (
|
||||
required_state_map.get((EventTypes.Member, local_user1_id))
|
||||
is not None
|
||||
):
|
||||
break
|
||||
|
||||
# Prevent tight-looping to allow the `test_timeout` to work
|
||||
time.sleep(0.1)
|
||||
|
||||
# Nothing needs to be sent from hs1 to hs2 since we already let the other
|
||||
# homeserver know by doing the `/make_join` and `/send_join` dance.
|
||||
self.assertIncludes(
|
||||
collected_pdus_from_hs1_federation_send,
|
||||
set(),
|
||||
exact=True,
|
||||
message="Didn't expect any events to be sent from hs1 over federation to hs2",
|
||||
)
|
||||
|
||||
return RemoteRoomJoinResult(
|
||||
remote_room_id=remote_room_id,
|
||||
room_version=room_version,
|
||||
remote_room_creator_user_id=room_creator_user_id,
|
||||
local_user1_id=local_user1_id,
|
||||
local_user1_tok=local_user1_tok,
|
||||
state_map=self.get_success(
|
||||
self.storage_controllers.state.get_current_state(remote_room_id)
|
||||
),
|
||||
)
|
||||
|
||||
def test_can_join_from_out_of_band_invite(self) -> None:
|
||||
"""
|
||||
Test to make sure that we can join a room that we were invited to over
|
||||
federation; even if our server has never participated in the room before.
|
||||
"""
|
||||
self._invite_local_user_to_remote_room_and_join()
|
||||
|
||||
@parameterized.expand(
|
||||
[("accept invite", Membership.JOIN), ("reject invite", Membership.LEAVE)]
|
||||
)
|
||||
def test_can_x_from_out_of_band_invite_after_we_are_already_participating_in_the_room(
|
||||
self, _test_description: str, membership_action: str
|
||||
) -> None:
|
||||
"""
|
||||
Test to make sure that we can do either a) join the room (accept the invite) or
|
||||
b) reject the invite after being invited to over federation; even if we are
|
||||
already participating in the room.
|
||||
|
||||
This is a regression test to make sure we stress the scenario where even though
|
||||
we are already participating in the room, local users can still react to invites
|
||||
regardless of whether the remote server has told us about the invite event (via
|
||||
a federation `/send` transaction) and we have de-outliered the invite event.
|
||||
Previously, we would mistakenly throw an error saying the user wasn't in the
|
||||
room when they tried to join or reject the invite.
|
||||
"""
|
||||
remote_room_join_result = self._invite_local_user_to_remote_room_and_join()
|
||||
remote_room_id = remote_room_join_result.remote_room_id
|
||||
room_version = remote_room_join_result.room_version
|
||||
|
||||
# Create another local user
|
||||
local_user2_id = self.register_user("user2", "pass")
|
||||
local_user2_tok = self.login(local_user2_id, "pass")
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
# PDU's that hs1 sent to hs2
|
||||
collected_pdus_from_hs1_federation_send: Set[str] = set()
|
||||
|
||||
async def put_json(
|
||||
destination: str,
|
||||
path: str,
|
||||
args: Optional[QueryParams] = None,
|
||||
data: Optional[JsonDict] = None,
|
||||
json_data_callback: Optional[Callable[[], JsonDict]] = None,
|
||||
long_retries: bool = False,
|
||||
timeout: Optional[int] = None,
|
||||
ignore_backoff: bool = False,
|
||||
backoff_on_404: bool = False,
|
||||
try_trailing_slash_on_400: bool = False,
|
||||
parser: Optional[ByteParser[T]] = None,
|
||||
backoff_on_all_error_codes: bool = False,
|
||||
) -> Union[JsonDict, T]:
|
||||
if path.startswith("/_matrix/federation/v1/send/") and data is not None:
|
||||
for pdu in data.get("pdus", []):
|
||||
event = event_from_pdu_json(pdu, room_version)
|
||||
collected_pdus_from_hs1_federation_send.add(event.event_id)
|
||||
|
||||
# Just acknowledge everything hs1 is trying to send hs2
|
||||
return {
|
||||
event_from_pdu_json(pdu, room_version).event_id: {}
|
||||
for pdu in data.get("pdus", [])
|
||||
}
|
||||
|
||||
raise NotImplementedError(
|
||||
"We have not mocked a response for `put_json(...)` for the following endpoint yet: "
|
||||
+ f"{destination}{path} with the following body data: {data}"
|
||||
)
|
||||
|
||||
self.federation_http_client.put_json.side_effect = put_json
|
||||
|
||||
# From the remote homeserver, invite user2 on the local homserver
|
||||
user2_invite_membership_event = make_event_from_dict(
|
||||
self.add_hashes_and_signatures_from_other_server(
|
||||
{
|
||||
"room_id": remote_room_id,
|
||||
"sender": remote_room_join_result.remote_room_creator_user_id,
|
||||
"depth": 5,
|
||||
"origin_server_ts": 5,
|
||||
"type": EventTypes.Member,
|
||||
"state_key": local_user2_id,
|
||||
"content": {"membership": Membership.INVITE},
|
||||
"auth_events": [
|
||||
remote_room_join_result.state_map[
|
||||
(EventTypes.Create, "")
|
||||
].event_id,
|
||||
remote_room_join_result.state_map[
|
||||
(
|
||||
EventTypes.Member,
|
||||
remote_room_join_result.remote_room_creator_user_id,
|
||||
)
|
||||
].event_id,
|
||||
],
|
||||
"prev_events": [
|
||||
remote_room_join_result.state_map[
|
||||
(EventTypes.Member, remote_room_join_result.local_user1_id)
|
||||
].event_id
|
||||
],
|
||||
}
|
||||
),
|
||||
room_version=room_version,
|
||||
)
|
||||
channel = self.make_signed_federation_request(
|
||||
"PUT",
|
||||
f"/_matrix/federation/v2/invite/{remote_room_id}/{user2_invite_membership_event.event_id}",
|
||||
content={
|
||||
"event": user2_invite_membership_event.get_dict(),
|
||||
"invite_room_state": [
|
||||
strip_event(
|
||||
remote_room_join_result.state_map[(EventTypes.Create, "")]
|
||||
),
|
||||
],
|
||||
"room_version": room_version.identifier,
|
||||
},
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
|
||||
|
||||
sync_body = {
|
||||
"lists": {
|
||||
"foo-list": {
|
||||
"ranges": [[0, 1]],
|
||||
"required_state": [(EventTypes.Member, StateValues.WILDCARD)],
|
||||
"timeline_limit": 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Sync until the local user2 can see the invite
|
||||
with test_timeout(
|
||||
3,
|
||||
"Unable to find user2's invite event in the room",
|
||||
):
|
||||
while True:
|
||||
response_body, _ = self.do_sync(sync_body, tok=local_user2_tok)
|
||||
if (
|
||||
remote_room_id in response_body["rooms"].keys()
|
||||
# If they have `invite_state` for the room, they are invited
|
||||
and len(
|
||||
response_body["rooms"][remote_room_id].get("invite_state", [])
|
||||
)
|
||||
> 0
|
||||
):
|
||||
break
|
||||
|
||||
# Prevent tight-looping to allow the `test_timeout` to work
|
||||
time.sleep(0.1)
|
||||
|
||||
if membership_action == Membership.JOIN:
|
||||
# User2 joins the room
|
||||
join_event = self.helper.join(
|
||||
remote_room_join_result.remote_room_id,
|
||||
local_user2_id,
|
||||
tok=local_user2_tok,
|
||||
)
|
||||
expected_pdu_event_id = join_event["event_id"]
|
||||
elif membership_action == Membership.LEAVE:
|
||||
# User2 rejects the invite
|
||||
leave_event = self.helper.leave(
|
||||
remote_room_join_result.remote_room_id,
|
||||
local_user2_id,
|
||||
tok=local_user2_tok,
|
||||
)
|
||||
expected_pdu_event_id = leave_event["event_id"]
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"This test does not support this membership action yet"
|
||||
)
|
||||
|
||||
# Sync until the local user2 can see their new membership in the room
|
||||
with test_timeout(
|
||||
3,
|
||||
"Unable to find user2's new membership event in the room",
|
||||
):
|
||||
while True:
|
||||
response_body, _ = self.do_sync(sync_body, tok=local_user2_tok)
|
||||
if membership_action == Membership.JOIN:
|
||||
if remote_room_id in response_body["rooms"].keys():
|
||||
required_state_map = required_state_json_to_state_map(
|
||||
response_body["rooms"][remote_room_id]["required_state"]
|
||||
)
|
||||
if (
|
||||
required_state_map.get((EventTypes.Member, local_user2_id))
|
||||
is not None
|
||||
):
|
||||
break
|
||||
elif membership_action == Membership.LEAVE:
|
||||
if remote_room_id not in response_body["rooms"].keys():
|
||||
break
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"This test does not support this membership action yet"
|
||||
)
|
||||
|
||||
# Prevent tight-looping to allow the `test_timeout` to work
|
||||
time.sleep(0.1)
|
||||
|
||||
# Make sure that we let hs2 know about the new membership event
|
||||
self.assertIncludes(
|
||||
collected_pdus_from_hs1_federation_send,
|
||||
{expected_pdu_event_id},
|
||||
exact=True,
|
||||
message="Expected to find the event ID of the user2 membership to be sent from hs1 over federation to hs2",
|
||||
)
|
||||
@@ -20,14 +20,21 @@
|
||||
#
|
||||
import logging
|
||||
from http import HTTPStatus
|
||||
from typing import Optional, Union
|
||||
from unittest.mock import Mock
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import FederationError
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
|
||||
from synapse.config.server import DEFAULT_ROOM_VERSION
|
||||
from synapse.events import EventBase, make_event_from_dict
|
||||
from synapse.federation.federation_base import event_from_pdu_json
|
||||
from synapse.http.types import QueryParams
|
||||
from synapse.logging.context import LoggingContext
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login, room
|
||||
from synapse.server import HomeServer
|
||||
@@ -85,6 +92,163 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
|
||||
self.assertEqual(500, channel.code, channel.result)
|
||||
|
||||
|
||||
def _create_acl_event(content: JsonDict) -> EventBase:
|
||||
return make_event_from_dict(
|
||||
{
|
||||
"room_id": "!a:b",
|
||||
"event_id": "$a:b",
|
||||
"type": "m.room.server_acls",
|
||||
"sender": "@a:b",
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class MessageAcceptTests(unittest.FederatingHomeserverTestCase):
|
||||
"""
|
||||
Tests to make sure that we don't accept flawed events from federation (incoming).
|
||||
"""
|
||||
|
||||
servlets = [
|
||||
admin.register_servlets,
|
||||
login.register_servlets,
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
self.http_client = Mock()
|
||||
return self.setup_test_homeserver(federation_http_client=self.http_client)
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
super().prepare(reactor, clock, hs)
|
||||
|
||||
self.store = self.hs.get_datastores().main
|
||||
self.storage_controllers = hs.get_storage_controllers()
|
||||
self.federation_event_handler = self.hs.get_federation_event_handler()
|
||||
|
||||
# Create a local room
|
||||
user1_id = self.register_user("user1", "pass")
|
||||
user1_tok = self.login(user1_id, "pass")
|
||||
self.room_id = self.helper.create_room_as(
|
||||
user1_id, tok=user1_tok, is_public=True
|
||||
)
|
||||
|
||||
state_map = self.get_success(
|
||||
self.storage_controllers.state.get_current_state(self.room_id)
|
||||
)
|
||||
|
||||
# Figure out what the forward extremities in the room are (the most recent
|
||||
# events that aren't tied into the DAG)
|
||||
forward_extremity_event_ids = self.get_success(
|
||||
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
|
||||
)
|
||||
|
||||
# Join a remote user to the room that will attempt to send bad events
|
||||
self.remote_bad_user_id = f"@baduser:{self.OTHER_SERVER_NAME}"
|
||||
self.remote_bad_user_join_event = make_event_from_dict(
|
||||
self.add_hashes_and_signatures_from_other_server(
|
||||
{
|
||||
"room_id": self.room_id,
|
||||
"sender": self.remote_bad_user_id,
|
||||
"state_key": self.remote_bad_user_id,
|
||||
"depth": 1000,
|
||||
"origin_server_ts": 1,
|
||||
"type": EventTypes.Member,
|
||||
"content": {"membership": Membership.JOIN},
|
||||
"auth_events": [
|
||||
state_map[(EventTypes.Create, "")].event_id,
|
||||
state_map[(EventTypes.JoinRules, "")].event_id,
|
||||
],
|
||||
"prev_events": list(forward_extremity_event_ids),
|
||||
}
|
||||
),
|
||||
room_version=RoomVersions.V10,
|
||||
)
|
||||
|
||||
# Send the join, it should return None (which is not an error)
|
||||
self.assertEqual(
|
||||
self.get_success(
|
||||
self.federation_event_handler.on_receive_pdu(
|
||||
self.OTHER_SERVER_NAME, self.remote_bad_user_join_event
|
||||
)
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
# Make sure we actually joined the room
|
||||
self.assertEqual(
|
||||
self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)),
|
||||
{self.remote_bad_user_join_event.event_id},
|
||||
)
|
||||
|
||||
def test_cant_hide_direct_ancestors(self) -> None:
|
||||
"""
|
||||
If you send a message, you must be able to provide the direct
|
||||
prev_events that said event references.
|
||||
"""
|
||||
|
||||
async def post_json(
|
||||
destination: str,
|
||||
path: str,
|
||||
data: Optional[JsonDict] = None,
|
||||
long_retries: bool = False,
|
||||
timeout: Optional[int] = None,
|
||||
ignore_backoff: bool = False,
|
||||
args: Optional[QueryParams] = None,
|
||||
) -> Union[JsonDict, list]:
|
||||
# If it asks us for new missing events, give them NOTHING
|
||||
if path.startswith("/_matrix/federation/v1/get_missing_events/"):
|
||||
return {"events": []}
|
||||
return {}
|
||||
|
||||
self.http_client.post_json = post_json
|
||||
|
||||
# Figure out what the forward extremities in the room are (the most recent
|
||||
# events that aren't tied into the DAG)
|
||||
forward_extremity_event_ids = self.get_success(
|
||||
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
|
||||
)
|
||||
|
||||
# Now lie about an event's prev_events
|
||||
lying_event = make_event_from_dict(
|
||||
self.add_hashes_and_signatures_from_other_server(
|
||||
{
|
||||
"room_id": self.room_id,
|
||||
"sender": self.remote_bad_user_id,
|
||||
"depth": 1000,
|
||||
"origin_server_ts": 1,
|
||||
"type": "m.room.message",
|
||||
"content": {"body": "hewwo?"},
|
||||
"auth_events": [],
|
||||
"prev_events": ["$missing_prev_event"]
|
||||
+ list(forward_extremity_event_ids),
|
||||
}
|
||||
),
|
||||
room_version=RoomVersions.V10,
|
||||
)
|
||||
|
||||
with LoggingContext("test-context"):
|
||||
failure = self.get_failure(
|
||||
self.federation_event_handler.on_receive_pdu(
|
||||
self.OTHER_SERVER_NAME, lying_event
|
||||
),
|
||||
FederationError,
|
||||
)
|
||||
|
||||
# on_receive_pdu should throw an error
|
||||
self.assertEqual(
|
||||
failure.value.args[0],
|
||||
(
|
||||
"ERROR 403: Your server isn't divulging details about prev_events "
|
||||
"referenced in this event."
|
||||
),
|
||||
)
|
||||
|
||||
# Make sure the invalid event isn't there
|
||||
extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
|
||||
self.assertEqual(extrem, {self.remote_bad_user_join_event.event_id})
|
||||
|
||||
|
||||
class ServerACLsTestCase(unittest.TestCase):
|
||||
def test_blocked_server(self) -> None:
|
||||
e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]})
|
||||
@@ -355,13 +519,76 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
|
||||
# is probably sufficient to reassure that the bucket is updated.
|
||||
|
||||
|
||||
def _create_acl_event(content: JsonDict) -> EventBase:
|
||||
return make_event_from_dict(
|
||||
{
|
||||
"room_id": "!a:b",
|
||||
"event_id": "$a:b",
|
||||
"type": "m.room.server_acls",
|
||||
"sender": "@a:b",
|
||||
"content": content,
|
||||
class StripUnsignedFromEventsTestCase(unittest.TestCase):
|
||||
"""
|
||||
Test to make sure that we handle the raw JSON events from federation carefully and
|
||||
strip anything that shouldn't be there.
|
||||
"""
|
||||
|
||||
def test_strip_unauthorized_unsigned_values(self) -> None:
|
||||
event1 = {
|
||||
"sender": "@baduser:test.serv",
|
||||
"state_key": "@baduser:test.serv",
|
||||
"event_id": "$event1:test.serv",
|
||||
"depth": 1000,
|
||||
"origin_server_ts": 1,
|
||||
"type": "m.room.member",
|
||||
"origin": "test.servx",
|
||||
"content": {"membership": "join"},
|
||||
"auth_events": [],
|
||||
"unsigned": {"malicious garbage": "hackz", "more warez": "more hackz"},
|
||||
}
|
||||
)
|
||||
filtered_event = event_from_pdu_json(event1, RoomVersions.V1)
|
||||
# Make sure unauthorized fields are stripped from unsigned
|
||||
self.assertNotIn("more warez", filtered_event.unsigned)
|
||||
|
||||
def test_strip_event_maintains_allowed_fields(self) -> None:
|
||||
event2 = {
|
||||
"sender": "@baduser:test.serv",
|
||||
"state_key": "@baduser:test.serv",
|
||||
"event_id": "$event2:test.serv",
|
||||
"depth": 1000,
|
||||
"origin_server_ts": 1,
|
||||
"type": "m.room.member",
|
||||
"origin": "test.servx",
|
||||
"auth_events": [],
|
||||
"content": {"membership": "join"},
|
||||
"unsigned": {
|
||||
"malicious garbage": "hackz",
|
||||
"more warez": "more hackz",
|
||||
"age": 14,
|
||||
"invite_room_state": [],
|
||||
},
|
||||
}
|
||||
|
||||
filtered_event2 = event_from_pdu_json(event2, RoomVersions.V1)
|
||||
self.assertIn("age", filtered_event2.unsigned)
|
||||
self.assertEqual(14, filtered_event2.unsigned["age"])
|
||||
self.assertNotIn("more warez", filtered_event2.unsigned)
|
||||
# Invite_room_state is allowed in events of type m.room.member
|
||||
self.assertIn("invite_room_state", filtered_event2.unsigned)
|
||||
self.assertEqual([], filtered_event2.unsigned["invite_room_state"])
|
||||
|
||||
def test_strip_event_removes_fields_based_on_event_type(self) -> None:
|
||||
event3 = {
|
||||
"sender": "@baduser:test.serv",
|
||||
"state_key": "@baduser:test.serv",
|
||||
"event_id": "$event3:test.serv",
|
||||
"depth": 1000,
|
||||
"origin_server_ts": 1,
|
||||
"type": "m.room.power_levels",
|
||||
"origin": "test.servx",
|
||||
"content": {},
|
||||
"auth_events": [],
|
||||
"unsigned": {
|
||||
"malicious garbage": "hackz",
|
||||
"more warez": "more hackz",
|
||||
"age": 14,
|
||||
"invite_room_state": [],
|
||||
},
|
||||
}
|
||||
filtered_event3 = event_from_pdu_json(event3, RoomVersions.V1)
|
||||
self.assertIn("age", filtered_event3.unsigned)
|
||||
# Invite_room_state field is only permitted in event type m.room.member
|
||||
self.assertNotIn("invite_room_state", filtered_event3.unsigned)
|
||||
self.assertNotIn("more warez", filtered_event3.unsigned)
|
||||
|
||||
@@ -375,7 +375,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
|
||||
|
||||
In this test, we pretend we are processing a "pulled" event via
|
||||
backfill. The pulled event succesfully processes and the backward
|
||||
extremeties are updated along with clearing out any failed pull attempts
|
||||
extremities are updated along with clearing out any failed pull attempts
|
||||
for those old extremities.
|
||||
|
||||
We check that we correctly cleared failed pull attempts of the
|
||||
|
||||
@@ -23,14 +23,21 @@ from typing import Optional, cast
|
||||
from unittest.mock import Mock, call
|
||||
|
||||
from parameterized import parameterized
|
||||
from signedjson.key import generate_signing_key
|
||||
from signedjson.key import (
|
||||
encode_verify_key_base64,
|
||||
generate_signing_key,
|
||||
get_verify_key,
|
||||
)
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership, PresenceState
|
||||
from synapse.api.presence import UserDevicePresenceState, UserPresenceState
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.events.builder import EventBuilder
|
||||
from synapse.api.room_versions import (
|
||||
RoomVersion,
|
||||
)
|
||||
from synapse.crypto.event_signing import add_hashes_and_signatures
|
||||
from synapse.events import EventBase, make_event_from_dict
|
||||
from synapse.federation.sender import FederationSender
|
||||
from synapse.handlers.presence import (
|
||||
BUSY_ONLINE_TIMEOUT,
|
||||
@@ -45,18 +52,24 @@ from synapse.handlers.presence import (
|
||||
handle_update,
|
||||
)
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import room
|
||||
from synapse.rest.client import login, room, sync
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.database import LoggingDatabaseConnection
|
||||
from synapse.storage.keys import FetchKeyResult
|
||||
from synapse.types import JsonDict, UserID, get_domain_from_id
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||
from tests.unittest import override_config
|
||||
|
||||
|
||||
class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [admin.register_servlets]
|
||||
servlets = [
|
||||
admin.register_servlets,
|
||||
login.register_servlets,
|
||||
sync.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(
|
||||
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
|
||||
@@ -425,6 +438,102 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
wheel_timer.insert.assert_not_called()
|
||||
|
||||
# `rc_presence` is set very high during unit tests to avoid ratelimiting
|
||||
# subtly impacting unrelated tests. We set the ratelimiting back to a
|
||||
# reasonable value for the tests specific to presence ratelimiting.
|
||||
@override_config(
|
||||
{"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}}
|
||||
)
|
||||
def test_over_ratelimit_offline_to_online_to_unavailable(self) -> None:
|
||||
"""
|
||||
Send a presence update, check that it went through, immediately send another one and
|
||||
check that it was ignored.
|
||||
"""
|
||||
self._test_ratelimit_offline_to_online_to_unavailable(ratelimited=True)
|
||||
|
||||
@override_config(
|
||||
{"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}}
|
||||
)
|
||||
def test_within_ratelimit_offline_to_online_to_unavailable(self) -> None:
|
||||
"""
|
||||
Send a presence update, check that it went through, advancing time a sufficient amount,
|
||||
send another presence update and check that it also worked.
|
||||
"""
|
||||
self._test_ratelimit_offline_to_online_to_unavailable(ratelimited=False)
|
||||
|
||||
@override_config(
|
||||
{"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}}
|
||||
)
|
||||
def _test_ratelimit_offline_to_online_to_unavailable(
|
||||
self, ratelimited: bool
|
||||
) -> None:
|
||||
"""Test rate limit for presence updates sent with sync requests.
|
||||
|
||||
Args:
|
||||
ratelimited: Test rate limited case.
|
||||
"""
|
||||
wheel_timer = Mock()
|
||||
user_id = "@user:pass"
|
||||
now = 5000000
|
||||
sync_url = "/sync?access_token=%s&set_presence=%s"
|
||||
|
||||
# Register the user who syncs presence
|
||||
user_id = self.register_user("user", "pass")
|
||||
access_token = self.login("user", "pass")
|
||||
|
||||
# Get the handler (which kicks off a bunch of timers).
|
||||
presence_handler = self.hs.get_presence_handler()
|
||||
|
||||
# Ensure the user is initially offline.
|
||||
prev_state = UserPresenceState.default(user_id)
|
||||
new_state = prev_state.copy_and_replace(
|
||||
state=PresenceState.OFFLINE, last_active_ts=now
|
||||
)
|
||||
|
||||
state, persist_and_notify, federation_ping = handle_update(
|
||||
prev_state,
|
||||
new_state,
|
||||
is_mine=True,
|
||||
wheel_timer=wheel_timer,
|
||||
now=now,
|
||||
persist=False,
|
||||
)
|
||||
|
||||
# Check that the user is offline.
|
||||
state = self.get_success(
|
||||
presence_handler.get_state(UserID.from_string(user_id))
|
||||
)
|
||||
self.assertEqual(state.state, PresenceState.OFFLINE)
|
||||
|
||||
# Send sync request with set_presence=online.
|
||||
channel = self.make_request("GET", sync_url % (access_token, "online"))
|
||||
self.assertEqual(200, channel.code)
|
||||
|
||||
# Assert the user is now online.
|
||||
state = self.get_success(
|
||||
presence_handler.get_state(UserID.from_string(user_id))
|
||||
)
|
||||
self.assertEqual(state.state, PresenceState.ONLINE)
|
||||
|
||||
if not ratelimited:
|
||||
# Advance time a sufficient amount to avoid rate limiting.
|
||||
self.reactor.advance(30)
|
||||
|
||||
# Send another sync request with set_presence=unavailable.
|
||||
channel = self.make_request("GET", sync_url % (access_token, "unavailable"))
|
||||
self.assertEqual(200, channel.code)
|
||||
|
||||
state = self.get_success(
|
||||
presence_handler.get_state(UserID.from_string(user_id))
|
||||
)
|
||||
|
||||
if ratelimited:
|
||||
# Assert the user is still online and presence update was ignored.
|
||||
self.assertEqual(state.state, PresenceState.ONLINE)
|
||||
else:
|
||||
# Assert the user is now unavailable.
|
||||
self.assertEqual(state.state, PresenceState.UNAVAILABLE)
|
||||
|
||||
|
||||
class PresenceTimeoutTestCase(unittest.TestCase):
|
||||
"""Tests different timers and that the timer does not change `status_msg` of user."""
|
||||
@@ -1825,6 +1934,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
|
||||
# self.event_builder_for_2.hostname = "test2"
|
||||
|
||||
self.store = hs.get_datastores().main
|
||||
self.storage_controllers = hs.get_storage_controllers()
|
||||
self.state = hs.get_state_handler()
|
||||
self._event_auth_handler = hs.get_event_auth_handler()
|
||||
|
||||
@@ -1940,29 +2050,35 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
hostname = get_domain_from_id(user_id)
|
||||
|
||||
room_version = self.get_success(self.store.get_room_version_id(room_id))
|
||||
room_version = self.get_success(self.store.get_room_version(room_id))
|
||||
|
||||
builder = EventBuilder(
|
||||
state=self.state,
|
||||
event_auth_handler=self._event_auth_handler,
|
||||
store=self.store,
|
||||
clock=self.clock,
|
||||
hostname=hostname,
|
||||
signing_key=self.random_signing_key,
|
||||
room_version=KNOWN_ROOM_VERSIONS[room_version],
|
||||
room_id=room_id,
|
||||
type=EventTypes.Member,
|
||||
sender=user_id,
|
||||
state_key=user_id,
|
||||
content={"membership": Membership.JOIN},
|
||||
state_map = self.get_success(
|
||||
self.storage_controllers.state.get_current_state(room_id)
|
||||
)
|
||||
|
||||
prev_event_ids = self.get_success(
|
||||
self.store.get_latest_event_ids_in_room(room_id)
|
||||
# Figure out what the forward extremities in the room are (the most recent
|
||||
# events that aren't tied into the DAG)
|
||||
forward_extremity_event_ids = self.get_success(
|
||||
self.hs.get_datastores().main.get_latest_event_ids_in_room(room_id)
|
||||
)
|
||||
|
||||
event = self.get_success(
|
||||
builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None)
|
||||
event = self.create_fake_event_from_remote_server(
|
||||
remote_server_name=hostname,
|
||||
event_dict={
|
||||
"room_id": room_id,
|
||||
"sender": user_id,
|
||||
"type": EventTypes.Member,
|
||||
"state_key": user_id,
|
||||
"depth": 1000,
|
||||
"origin_server_ts": 1,
|
||||
"content": {"membership": Membership.JOIN},
|
||||
"auth_events": [
|
||||
state_map[(EventTypes.Create, "")].event_id,
|
||||
state_map[(EventTypes.JoinRules, "")].event_id,
|
||||
],
|
||||
"prev_events": list(forward_extremity_event_ids),
|
||||
},
|
||||
room_version=room_version,
|
||||
)
|
||||
|
||||
self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event))
|
||||
@@ -1970,3 +2086,50 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
|
||||
# Check that it was successfully persisted.
|
||||
self.get_success(self.store.get_event(event.event_id))
|
||||
self.get_success(self.store.get_event(event.event_id))
|
||||
|
||||
def create_fake_event_from_remote_server(
|
||||
self, remote_server_name: str, event_dict: JsonDict, room_version: RoomVersion
|
||||
) -> EventBase:
|
||||
"""
|
||||
This is similar to what `FederatingHomeserverTestCase` is doing but we don't
|
||||
need all of the extra baggage and we want to be able to create an event from
|
||||
many remote servers.
|
||||
"""
|
||||
|
||||
# poke the other server's signing key into the key store, so that we don't
|
||||
# make requests for it
|
||||
other_server_signature_key = generate_signing_key("test")
|
||||
verify_key = get_verify_key(other_server_signature_key)
|
||||
verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version)
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_datastores().main.store_server_keys_response(
|
||||
remote_server_name,
|
||||
from_server=remote_server_name,
|
||||
ts_added_ms=self.clock.time_msec(),
|
||||
verify_keys={
|
||||
verify_key_id: FetchKeyResult(
|
||||
verify_key=verify_key,
|
||||
valid_until_ts=self.clock.time_msec() + 10000,
|
||||
),
|
||||
},
|
||||
response_json={
|
||||
"verify_keys": {
|
||||
verify_key_id: {"key": encode_verify_key_base64(verify_key)}
|
||||
}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
add_hashes_and_signatures(
|
||||
room_version=room_version,
|
||||
event_dict=event_dict,
|
||||
signature_name=remote_server_name,
|
||||
signing_key=other_server_signature_key,
|
||||
)
|
||||
event = make_event_from_dict(
|
||||
event_dict,
|
||||
room_version=room_version,
|
||||
)
|
||||
|
||||
return event
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
# [This file includes modifications made by New Vector Limited]
|
||||
#
|
||||
#
|
||||
from http import HTTPStatus
|
||||
from typing import Collection, ContextManager, List, Optional
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
@@ -347,7 +348,15 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
|
||||
# the prev_events used when creating the join event, such that the ban does not
|
||||
# precede the join.
|
||||
with self._patch_get_latest_events([last_room_creation_event_id]):
|
||||
self.helper.join(room_id, eve, tok=eve_token)
|
||||
self.helper.join(
|
||||
room_id,
|
||||
eve,
|
||||
tok=eve_token,
|
||||
# Previously, this join would succeed but now we expect it to fail at
|
||||
# this point. The rest of the test is for the case when this used to
|
||||
# succeed.
|
||||
expect_code=HTTPStatus.FORBIDDEN,
|
||||
)
|
||||
|
||||
# Eve makes a second, incremental sync.
|
||||
eve_incremental_sync_after_join: SyncResult = self.get_success(
|
||||
|
||||
@@ -22,14 +22,26 @@ import logging
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
from netaddr import IPSet
|
||||
from signedjson.key import (
|
||||
encode_verify_key_base64,
|
||||
generate_signing_key,
|
||||
get_verify_key,
|
||||
)
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.events.builder import EventBuilderFactory
|
||||
from synapse.api.room_versions import RoomVersion
|
||||
from synapse.crypto.event_signing import add_hashes_and_signatures
|
||||
from synapse.events import EventBase, make_event_from_dict
|
||||
from synapse.handlers.typing import TypingWriterHandler
|
||||
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
|
||||
from synapse.rest.admin import register_servlets_for_client_rest_resource
|
||||
from synapse.rest.client import login, room
|
||||
from synapse.types import UserID, create_requester
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.keys import FetchKeyResult
|
||||
from synapse.types import JsonDict, UserID, create_requester
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||
from tests.server import get_clock
|
||||
@@ -63,6 +75,9 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
|
||||
ip_blocklist=IPSet(),
|
||||
)
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.storage_controllers = hs.get_storage_controllers()
|
||||
|
||||
def test_send_event_single_sender(self) -> None:
|
||||
"""Test that using a single federation sender worker correctly sends a
|
||||
new event.
|
||||
@@ -243,35 +258,92 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
|
||||
self.assertTrue(sent_on_1)
|
||||
self.assertTrue(sent_on_2)
|
||||
|
||||
def create_fake_event_from_remote_server(
|
||||
self, remote_server_name: str, event_dict: JsonDict, room_version: RoomVersion
|
||||
) -> EventBase:
|
||||
"""
|
||||
This is similar to what `FederatingHomeserverTestCase` is doing but we don't
|
||||
need all of the extra baggage and we want to be able to create an event from
|
||||
many remote servers.
|
||||
"""
|
||||
|
||||
# poke the other server's signing key into the key store, so that we don't
|
||||
# make requests for it
|
||||
other_server_signature_key = generate_signing_key("test")
|
||||
verify_key = get_verify_key(other_server_signature_key)
|
||||
verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version)
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_datastores().main.store_server_keys_response(
|
||||
remote_server_name,
|
||||
from_server=remote_server_name,
|
||||
ts_added_ms=self.clock.time_msec(),
|
||||
verify_keys={
|
||||
verify_key_id: FetchKeyResult(
|
||||
verify_key=verify_key,
|
||||
valid_until_ts=self.clock.time_msec() + 10000,
|
||||
),
|
||||
},
|
||||
response_json={
|
||||
"verify_keys": {
|
||||
verify_key_id: {"key": encode_verify_key_base64(verify_key)}
|
||||
}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
add_hashes_and_signatures(
|
||||
room_version=room_version,
|
||||
event_dict=event_dict,
|
||||
signature_name=remote_server_name,
|
||||
signing_key=other_server_signature_key,
|
||||
)
|
||||
event = make_event_from_dict(
|
||||
event_dict,
|
||||
room_version=room_version,
|
||||
)
|
||||
|
||||
return event
|
||||
|
||||
def create_room_with_remote_server(
|
||||
self, user: str, token: str, remote_server: str = "other_server"
|
||||
) -> str:
|
||||
room = self.helper.create_room_as(user, tok=token)
|
||||
room_id = self.helper.create_room_as(user, tok=token)
|
||||
store = self.hs.get_datastores().main
|
||||
federation = self.hs.get_federation_event_handler()
|
||||
|
||||
prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
|
||||
room_version = self.get_success(store.get_room_version(room))
|
||||
room_version = self.get_success(store.get_room_version(room_id))
|
||||
|
||||
factory = EventBuilderFactory(self.hs)
|
||||
factory.hostname = remote_server
|
||||
state_map = self.get_success(
|
||||
self.storage_controllers.state.get_current_state(room_id)
|
||||
)
|
||||
|
||||
# Figure out what the forward extremities in the room are (the most recent
|
||||
# events that aren't tied into the DAG)
|
||||
prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room_id))
|
||||
|
||||
user_id = UserID("user", remote_server).to_string()
|
||||
|
||||
event_dict = {
|
||||
"type": EventTypes.Member,
|
||||
"state_key": user_id,
|
||||
"content": {"membership": Membership.JOIN},
|
||||
"sender": user_id,
|
||||
"room_id": room,
|
||||
}
|
||||
|
||||
builder = factory.for_room_version(room_version, event_dict)
|
||||
join_event = self.get_success(
|
||||
builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None)
|
||||
join_event = self.create_fake_event_from_remote_server(
|
||||
remote_server_name=remote_server,
|
||||
event_dict={
|
||||
"room_id": room_id,
|
||||
"sender": user_id,
|
||||
"type": EventTypes.Member,
|
||||
"state_key": user_id,
|
||||
"depth": 1000,
|
||||
"origin_server_ts": 1,
|
||||
"content": {"membership": Membership.JOIN},
|
||||
"auth_events": [
|
||||
state_map[(EventTypes.Create, "")].event_id,
|
||||
state_map[(EventTypes.JoinRules, "")].event_id,
|
||||
],
|
||||
"prev_events": list(prev_event_ids),
|
||||
},
|
||||
room_version=room_version,
|
||||
)
|
||||
|
||||
self.get_success(federation.on_send_membership_event(remote_server, join_event))
|
||||
self.replicate()
|
||||
|
||||
return room
|
||||
return room_id
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
# Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from http import HTTPStatus
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from synapse.rest.client import auth_issuer
|
||||
|
||||
from tests.unittest import HomeserverTestCase, override_config, skip_unless
|
||||
from tests.utils import HAS_AUTHLIB
|
||||
|
||||
ISSUER = "https://account.example.com/"
|
||||
|
||||
|
||||
class AuthIssuerTestCase(HomeserverTestCase):
|
||||
servlets = [
|
||||
auth_issuer.register_servlets,
|
||||
]
|
||||
|
||||
def test_returns_404_when_msc3861_disabled(self) -> None:
|
||||
# Make an unauthenticated request for the discovery info.
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/unstable/org.matrix.msc2965/auth_issuer",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
|
||||
|
||||
@skip_unless(HAS_AUTHLIB, "requires authlib")
|
||||
@override_config(
|
||||
{
|
||||
"disable_registration": True,
|
||||
"experimental_features": {
|
||||
"msc3861": {
|
||||
"enabled": True,
|
||||
"issuer": ISSUER,
|
||||
"client_id": "David Lister",
|
||||
"client_auth_method": "client_secret_post",
|
||||
"client_secret": "Who shot Mister Burns?",
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
def test_returns_issuer_when_oidc_enabled(self) -> None:
|
||||
# Patch the HTTP client to return the issuer metadata
|
||||
req_mock = AsyncMock(return_value={"issuer": ISSUER})
|
||||
self.hs.get_proxied_http_client().get_json = req_mock # type: ignore[method-assign]
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/unstable/org.matrix.msc2965/auth_issuer",
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
self.assertEqual(channel.json_body, {"issuer": ISSUER})
|
||||
|
||||
req_mock.assert_called_with(
|
||||
"https://account.example.com/.well-known/openid-configuration"
|
||||
)
|
||||
req_mock.reset_mock()
|
||||
|
||||
# Second call it should use the cached value
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/unstable/org.matrix.msc2965/auth_issuer",
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
self.assertEqual(channel.json_body, {"issuer": ISSUER})
|
||||
req_mock.assert_not_called()
|
||||
140
tests/rest/client/test_auth_metadata.py
Normal file
140
tests/rest/client/test_auth_metadata.py
Normal file
@@ -0,0 +1,140 @@
|
||||
#
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright 2023 The Matrix.org Foundation C.I.C
|
||||
# Copyright (C) 2023-2025 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
#
|
||||
# Originally licensed under the Apache License, Version 2.0:
|
||||
# <http://www.apache.org/licenses/LICENSE-2.0>.
|
||||
#
|
||||
# [This file includes modifications made by New Vector Limited]
|
||||
#
|
||||
from http import HTTPStatus
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from synapse.rest.client import auth_metadata
|
||||
|
||||
from tests.unittest import HomeserverTestCase, override_config, skip_unless
|
||||
from tests.utils import HAS_AUTHLIB
|
||||
|
||||
ISSUER = "https://account.example.com/"
|
||||
|
||||
|
||||
class AuthIssuerTestCase(HomeserverTestCase):
|
||||
servlets = [
|
||||
auth_metadata.register_servlets,
|
||||
]
|
||||
|
||||
def test_returns_404_when_msc3861_disabled(self) -> None:
|
||||
# Make an unauthenticated request for the discovery info.
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/unstable/org.matrix.msc2965/auth_issuer",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
|
||||
|
||||
@skip_unless(HAS_AUTHLIB, "requires authlib")
|
||||
@override_config(
|
||||
{
|
||||
"disable_registration": True,
|
||||
"experimental_features": {
|
||||
"msc3861": {
|
||||
"enabled": True,
|
||||
"issuer": ISSUER,
|
||||
"client_id": "David Lister",
|
||||
"client_auth_method": "client_secret_post",
|
||||
"client_secret": "Who shot Mister Burns?",
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
def test_returns_issuer_when_oidc_enabled(self) -> None:
|
||||
# Patch the HTTP client to return the issuer metadata
|
||||
req_mock = AsyncMock(return_value={"issuer": ISSUER})
|
||||
self.hs.get_proxied_http_client().get_json = req_mock # type: ignore[method-assign]
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/unstable/org.matrix.msc2965/auth_issuer",
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
self.assertEqual(channel.json_body, {"issuer": ISSUER})
|
||||
|
||||
req_mock.assert_called_with(
|
||||
"https://account.example.com/.well-known/openid-configuration"
|
||||
)
|
||||
req_mock.reset_mock()
|
||||
|
||||
# Second call it should use the cached value
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/unstable/org.matrix.msc2965/auth_issuer",
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
self.assertEqual(channel.json_body, {"issuer": ISSUER})
|
||||
req_mock.assert_not_called()
|
||||
|
||||
|
||||
class AuthMetadataTestCase(HomeserverTestCase):
|
||||
servlets = [
|
||||
auth_metadata.register_servlets,
|
||||
]
|
||||
|
||||
def test_returns_404_when_msc3861_disabled(self) -> None:
|
||||
# Make an unauthenticated request for the discovery info.
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/unstable/org.matrix.msc2965/auth_metadata",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
|
||||
|
||||
@skip_unless(HAS_AUTHLIB, "requires authlib")
|
||||
@override_config(
|
||||
{
|
||||
"disable_registration": True,
|
||||
"experimental_features": {
|
||||
"msc3861": {
|
||||
"enabled": True,
|
||||
"issuer": ISSUER,
|
||||
"client_id": "David Lister",
|
||||
"client_auth_method": "client_secret_post",
|
||||
"client_secret": "Who shot Mister Burns?",
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
def test_returns_issuer_when_oidc_enabled(self) -> None:
|
||||
# Patch the HTTP client to return the issuer metadata
|
||||
req_mock = AsyncMock(
|
||||
return_value={
|
||||
"issuer": ISSUER,
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
}
|
||||
)
|
||||
self.hs.get_proxied_http_client().get_json = req_mock # type: ignore[method-assign]
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/unstable/org.matrix.msc2965/auth_metadata",
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
self.assertEqual(
|
||||
channel.json_body,
|
||||
{
|
||||
"issuer": ISSUER,
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
},
|
||||
)
|
||||
@@ -142,6 +142,50 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
self.assertFalse(capabilities["m.set_avatar_url"]["enabled"])
|
||||
|
||||
@override_config(
|
||||
{
|
||||
"enable_set_displayname": False,
|
||||
"experimental_features": {"msc4133_enabled": True},
|
||||
}
|
||||
)
|
||||
def test_get_set_displayname_capabilities_displayname_disabled_msc4133(
|
||||
self,
|
||||
) -> None:
|
||||
"""Test if set displayname is disabled that the server responds it."""
|
||||
access_token = self.login(self.localpart, self.password)
|
||||
|
||||
channel = self.make_request("GET", self.url, access_token=access_token)
|
||||
capabilities = channel.json_body["capabilities"]
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
self.assertFalse(capabilities["m.set_displayname"]["enabled"])
|
||||
self.assertTrue(capabilities["uk.tcpip.msc4133.profile_fields"]["enabled"])
|
||||
self.assertEqual(
|
||||
capabilities["uk.tcpip.msc4133.profile_fields"]["disallowed"],
|
||||
["displayname"],
|
||||
)
|
||||
|
||||
@override_config(
|
||||
{
|
||||
"enable_set_avatar_url": False,
|
||||
"experimental_features": {"msc4133_enabled": True},
|
||||
}
|
||||
)
|
||||
def test_get_set_avatar_url_capabilities_avatar_url_disabled_msc4133(self) -> None:
|
||||
"""Test if set avatar_url is disabled that the server responds it."""
|
||||
access_token = self.login(self.localpart, self.password)
|
||||
|
||||
channel = self.make_request("GET", self.url, access_token=access_token)
|
||||
capabilities = channel.json_body["capabilities"]
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
self.assertFalse(capabilities["m.set_avatar_url"]["enabled"])
|
||||
self.assertTrue(capabilities["uk.tcpip.msc4133.profile_fields"]["enabled"])
|
||||
self.assertEqual(
|
||||
capabilities["uk.tcpip.msc4133.profile_fields"]["disallowed"],
|
||||
["avatar_url"],
|
||||
)
|
||||
|
||||
@override_config({"enable_3pid_changes": False})
|
||||
def test_get_change_3pid_capabilities_3pid_disabled(self) -> None:
|
||||
"""Test if change 3pid is disabled that the server responds it."""
|
||||
|
||||
@@ -29,6 +29,7 @@ from synapse.types import UserID
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.unittest import override_config
|
||||
|
||||
|
||||
class PresenceTestCase(unittest.HomeserverTestCase):
|
||||
@@ -95,3 +96,54 @@ class PresenceTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
self.assertEqual(self.presence_handler.set_state.call_count, 0)
|
||||
|
||||
@override_config(
|
||||
{"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}}
|
||||
)
|
||||
def test_put_presence_over_ratelimit(self) -> None:
|
||||
"""
|
||||
Multiple PUTs to the status endpoint without sufficient delay will be rate limited.
|
||||
"""
|
||||
self.hs.config.server.presence_enabled = True
|
||||
|
||||
body = {"presence": "here", "status_msg": "beep boop"}
|
||||
channel = self.make_request(
|
||||
"PUT", "/presence/%s/status" % (self.user_id,), body
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
|
||||
body = {"presence": "here", "status_msg": "beep boop"}
|
||||
channel = self.make_request(
|
||||
"PUT", "/presence/%s/status" % (self.user_id,), body
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.TOO_MANY_REQUESTS)
|
||||
self.assertEqual(self.presence_handler.set_state.call_count, 1)
|
||||
|
||||
@override_config(
|
||||
{"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}}
|
||||
)
|
||||
def test_put_presence_within_ratelimit(self) -> None:
|
||||
"""
|
||||
Multiple PUTs to the status endpoint with sufficient delay should all call set_state.
|
||||
"""
|
||||
self.hs.config.server.presence_enabled = True
|
||||
|
||||
body = {"presence": "here", "status_msg": "beep boop"}
|
||||
channel = self.make_request(
|
||||
"PUT", "/presence/%s/status" % (self.user_id,), body
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
|
||||
# Advance time a sufficient amount to avoid rate limiting.
|
||||
self.reactor.advance(30)
|
||||
|
||||
body = {"presence": "here", "status_msg": "beep boop"}
|
||||
channel = self.make_request(
|
||||
"PUT", "/presence/%s/status" % (self.user_id,), body
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
self.assertEqual(self.presence_handler.set_state.call_count, 2)
|
||||
|
||||
@@ -25,16 +25,20 @@ import urllib.parse
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.errors import Codes
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login, profile, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.databases.main.profile import MAX_PROFILE_SIZE
|
||||
from synapse.types import UserID
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.utils import USE_POSTGRES_FOR_TESTS
|
||||
|
||||
|
||||
class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
@@ -480,6 +484,298 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
# The client requested ?propagate=true, so it should have happened.
|
||||
self.assertEqual(channel.json_body.get(prop), "http://my.server/pic.gif")
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
|
||||
def test_get_missing_custom_field(self) -> None:
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
|
||||
def test_get_missing_custom_field_invalid_field_name(self) -> None:
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/[custom_field]",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
|
||||
def test_get_custom_field_rejects_bad_username(self) -> None:
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{urllib.parse.quote('@alice:')}/custom_field",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
|
||||
def test_set_custom_field(self) -> None:
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
|
||||
content={"custom_field": "test"},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.json_body, {"custom_field": "test"})
|
||||
|
||||
# Overwriting the field should work.
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
|
||||
content={"custom_field": "new_Value"},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.json_body, {"custom_field": "new_Value"})
|
||||
|
||||
# Deleting the field should work.
|
||||
channel = self.make_request(
|
||||
"DELETE",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
|
||||
content={},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
|
||||
def test_non_string(self) -> None:
|
||||
"""Non-string fields are supported for custom fields."""
|
||||
fields = {
|
||||
"bool_field": True,
|
||||
"array_field": ["test"],
|
||||
"object_field": {"test": "test"},
|
||||
"numeric_field": 1,
|
||||
"null_field": None,
|
||||
}
|
||||
|
||||
for key, value in fields.items():
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
|
||||
content={key: value},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/v3/profile/{self.owner}",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.json_body, {"displayname": "owner", **fields})
|
||||
|
||||
# Check getting individual fields works.
|
||||
for key, value in fields.items():
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.json_body, {key: value})
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
|
||||
def test_set_custom_field_noauth(self) -> None:
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
|
||||
content={"custom_field": "test"},
|
||||
)
|
||||
self.assertEqual(channel.code, 401, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.MISSING_TOKEN)
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
|
||||
def test_set_custom_field_size(self) -> None:
|
||||
"""
|
||||
Attempts to set a custom field name that is too long should get a 400 error.
|
||||
"""
|
||||
# Key is missing.
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/",
|
||||
content={"": "test"},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
|
||||
# Single key is too large.
|
||||
key = "c" * 500
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
|
||||
content={key: "test"},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.KEY_TOO_LARGE)
|
||||
|
||||
channel = self.make_request(
|
||||
"DELETE",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
|
||||
content={key: "test"},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.KEY_TOO_LARGE)
|
||||
|
||||
# Key doesn't match body.
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
|
||||
content={"diff_key": "test"},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM)
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
|
||||
def test_set_custom_field_profile_too_long(self) -> None:
|
||||
"""
|
||||
Attempts to set a custom field that would push the overall profile too large.
|
||||
"""
|
||||
# Get right to the boundary:
|
||||
# len("displayname") + len("owner") + 5 = 21 for the displayname
|
||||
# 1 + 65498 + 5 for key "a" = 65504
|
||||
# 2 braces, 1 comma
|
||||
# 3 + 21 + 65498 = 65522 < 65536.
|
||||
key = "a"
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
|
||||
content={key: "a" * 65498},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
# Get the entire profile.
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/v3/profile/{self.owner}",
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
canonical_json = encode_canonical_json(channel.json_body)
|
||||
# 6 is the minimum bytes to store a value: 4 quotes, 1 colon, 1 comma, an empty key.
|
||||
# Be one below that so we can prove we're at the boundary.
|
||||
self.assertEqual(len(canonical_json), MAX_PROFILE_SIZE - 8)
|
||||
|
||||
# Postgres stores JSONB with whitespace, while SQLite doesn't.
|
||||
if USE_POSTGRES_FOR_TESTS:
|
||||
ADDITIONAL_CHARS = 0
|
||||
else:
|
||||
ADDITIONAL_CHARS = 1
|
||||
|
||||
# The next one should fail, note the value has a (JSON) length of 2.
|
||||
key = "b"
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
|
||||
content={key: "1" + "a" * ADDITIONAL_CHARS},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)
|
||||
|
||||
# Setting an avatar or (longer) display name should not work.
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/profile/{self.owner}/displayname",
|
||||
content={"displayname": "owner12345678" + "a" * ADDITIONAL_CHARS},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)
|
||||
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/profile/{self.owner}/avatar_url",
|
||||
content={"avatar_url": "mxc://foo/bar"},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)
|
||||
|
||||
# Removing a single byte should work.
|
||||
key = "b"
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
|
||||
content={key: "" + "a" * ADDITIONAL_CHARS},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
# Finally, setting a field that already exists to a value that is <= in length should work.
|
||||
key = "a"
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
|
||||
content={key: ""},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
|
||||
def test_set_custom_field_displayname(self) -> None:
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/displayname",
|
||||
content={"displayname": "test"},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
displayname = self._get_displayname()
|
||||
self.assertEqual(displayname, "test")
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
|
||||
def test_set_custom_field_avatar_url(self) -> None:
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/avatar_url",
|
||||
content={"avatar_url": "mxc://test/good"},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
avatar_url = self._get_avatar_url()
|
||||
self.assertEqual(avatar_url, "mxc://test/good")
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
|
||||
def test_set_custom_field_other(self) -> None:
|
||||
"""Setting someone else's profile field should fail"""
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.other}/custom_field",
|
||||
content={"custom_field": "test"},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 403, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
|
||||
|
||||
def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]) -> None:
|
||||
"""Stores metadata about files in the database.
|
||||
|
||||
|
||||
@@ -742,7 +742,7 @@ class RoomsCreateTestCase(RoomBase):
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
self.assertTrue("room_id" in channel.json_body)
|
||||
assert channel.resource_usage is not None
|
||||
self.assertEqual(33, channel.resource_usage.db_txn_count)
|
||||
self.assertEqual(34, channel.resource_usage.db_txn_count)
|
||||
|
||||
def test_post_room_initial_state(self) -> None:
|
||||
# POST with initial_state config key, expect new room id
|
||||
@@ -755,7 +755,7 @@ class RoomsCreateTestCase(RoomBase):
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
self.assertTrue("room_id" in channel.json_body)
|
||||
assert channel.resource_usage is not None
|
||||
self.assertEqual(35, channel.resource_usage.db_txn_count)
|
||||
self.assertEqual(36, channel.resource_usage.db_txn_count)
|
||||
|
||||
def test_post_room_visibility_key(self) -> None:
|
||||
# POST with visibility config key, expect new room id
|
||||
|
||||
@@ -1,378 +0,0 @@
|
||||
#
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
# Copyright (C) 2023 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
#
|
||||
# Originally licensed under the Apache License, Version 2.0:
|
||||
# <http://www.apache.org/licenses/LICENSE-2.0>.
|
||||
#
|
||||
# [This file includes modifications made by New Vector Limited]
|
||||
#
|
||||
#
|
||||
|
||||
from typing import Collection, List, Optional, Union
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.errors import FederationError
|
||||
from synapse.api.room_versions import RoomVersion, RoomVersions
|
||||
from synapse.events import EventBase, make_event_from_dict
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.federation.federation_base import event_from_pdu_json
|
||||
from synapse.handlers.device import DeviceListUpdater
|
||||
from synapse.http.types import QueryParams
|
||||
from synapse.logging.context import LoggingContext
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, UserID, create_requester
|
||||
from synapse.util import Clock
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
||||
from tests import unittest
|
||||
|
||||
|
||||
class MessageAcceptTests(unittest.HomeserverTestCase):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
self.http_client = Mock()
|
||||
return self.setup_test_homeserver(federation_http_client=self.http_client)
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
user_id = UserID("us", "test")
|
||||
our_user = create_requester(user_id)
|
||||
room_creator = self.hs.get_room_creation_handler()
|
||||
self.room_id = self.get_success(
|
||||
room_creator.create_room(
|
||||
our_user, room_creator._presets_dict["public_chat"], ratelimit=False
|
||||
)
|
||||
)[0]
|
||||
|
||||
self.store = self.hs.get_datastores().main
|
||||
|
||||
# Figure out what the most recent event is
|
||||
most_recent = next(
|
||||
iter(
|
||||
self.get_success(
|
||||
self.hs.get_datastores().main.get_latest_event_ids_in_room(
|
||||
self.room_id
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
join_event = make_event_from_dict(
|
||||
{
|
||||
"room_id": self.room_id,
|
||||
"sender": "@baduser:test.serv",
|
||||
"state_key": "@baduser:test.serv",
|
||||
"event_id": "$join:test.serv",
|
||||
"depth": 1000,
|
||||
"origin_server_ts": 1,
|
||||
"type": "m.room.member",
|
||||
"origin": "test.servx",
|
||||
"content": {"membership": "join"},
|
||||
"auth_events": [],
|
||||
"prev_state": [(most_recent, {})],
|
||||
"prev_events": [(most_recent, {})],
|
||||
}
|
||||
)
|
||||
|
||||
self.handler = self.hs.get_federation_handler()
|
||||
federation_event_handler = self.hs.get_federation_event_handler()
|
||||
|
||||
async def _check_event_auth(
|
||||
origin: Optional[str], event: EventBase, context: EventContext
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
federation_event_handler._check_event_auth = _check_event_auth # type: ignore[method-assign]
|
||||
self.client = self.hs.get_federation_client()
|
||||
|
||||
async def _check_sigs_and_hash_for_pulled_events_and_fetch(
|
||||
dest: str, pdus: Collection[EventBase], room_version: RoomVersion
|
||||
) -> List[EventBase]:
|
||||
return list(pdus)
|
||||
|
||||
self.client._check_sigs_and_hash_for_pulled_events_and_fetch = ( # type: ignore[method-assign]
|
||||
_check_sigs_and_hash_for_pulled_events_and_fetch # type: ignore[assignment]
|
||||
)
|
||||
|
||||
# Send the join, it should return None (which is not an error)
|
||||
self.assertEqual(
|
||||
self.get_success(
|
||||
federation_event_handler.on_receive_pdu("test.serv", join_event)
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
# Make sure we actually joined the room
|
||||
self.assertEqual(
|
||||
self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)),
|
||||
{"$join:test.serv"},
|
||||
)
|
||||
|
||||
def test_cant_hide_direct_ancestors(self) -> None:
|
||||
"""
|
||||
If you send a message, you must be able to provide the direct
|
||||
prev_events that said event references.
|
||||
"""
|
||||
|
||||
async def post_json(
|
||||
destination: str,
|
||||
path: str,
|
||||
data: Optional[JsonDict] = None,
|
||||
long_retries: bool = False,
|
||||
timeout: Optional[int] = None,
|
||||
ignore_backoff: bool = False,
|
||||
args: Optional[QueryParams] = None,
|
||||
) -> Union[JsonDict, list]:
|
||||
# If it asks us for new missing events, give them NOTHING
|
||||
if path.startswith("/_matrix/federation/v1/get_missing_events/"):
|
||||
return {"events": []}
|
||||
return {}
|
||||
|
||||
self.http_client.post_json = post_json
|
||||
|
||||
# Figure out what the most recent event is
|
||||
most_recent = next(
|
||||
iter(
|
||||
self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
|
||||
)
|
||||
)
|
||||
|
||||
# Now lie about an event
|
||||
lying_event = make_event_from_dict(
|
||||
{
|
||||
"room_id": self.room_id,
|
||||
"sender": "@baduser:test.serv",
|
||||
"event_id": "one:test.serv",
|
||||
"depth": 1000,
|
||||
"origin_server_ts": 1,
|
||||
"type": "m.room.message",
|
||||
"origin": "test.serv",
|
||||
"content": {"body": "hewwo?"},
|
||||
"auth_events": [],
|
||||
"prev_events": [("two:test.serv", {}), (most_recent, {})],
|
||||
}
|
||||
)
|
||||
|
||||
federation_event_handler = self.hs.get_federation_event_handler()
|
||||
with LoggingContext("test-context"):
|
||||
failure = self.get_failure(
|
||||
federation_event_handler.on_receive_pdu("test.serv", lying_event),
|
||||
FederationError,
|
||||
)
|
||||
|
||||
# on_receive_pdu should throw an error
|
||||
self.assertEqual(
|
||||
failure.value.args[0],
|
||||
(
|
||||
"ERROR 403: Your server isn't divulging details about prev_events "
|
||||
"referenced in this event."
|
||||
),
|
||||
)
|
||||
|
||||
# Make sure the invalid event isn't there
|
||||
extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
|
||||
self.assertEqual(extrem, {"$join:test.serv"})
|
||||
|
||||
def test_retry_device_list_resync(self) -> None:
|
||||
"""Tests that device lists are marked as stale if they couldn't be synced, and
|
||||
that stale device lists are retried periodically.
|
||||
"""
|
||||
remote_user_id = "@john:test_remote"
|
||||
remote_origin = "test_remote"
|
||||
|
||||
# Track the number of attempts to resync the user's device list.
|
||||
self.resync_attempts = 0
|
||||
|
||||
# When this function is called, increment the number of resync attempts (only if
|
||||
# we're querying devices for the right user ID), then raise a
|
||||
# NotRetryingDestination error to fail the resync gracefully.
|
||||
def query_user_devices(
|
||||
destination: str, user_id: str, timeout: int = 30000
|
||||
) -> JsonDict:
|
||||
if user_id == remote_user_id:
|
||||
self.resync_attempts += 1
|
||||
|
||||
raise NotRetryingDestination(0, 0, destination)
|
||||
|
||||
# Register the mock on the federation client.
|
||||
federation_client = self.hs.get_federation_client()
|
||||
federation_client.query_user_devices = Mock(side_effect=query_user_devices) # type: ignore[method-assign]
|
||||
|
||||
# Register a mock on the store so that the incoming update doesn't fail because
|
||||
# we don't share a room with the user.
|
||||
store = self.hs.get_datastores().main
|
||||
store.get_rooms_for_user = AsyncMock(return_value=["!someroom:test"])
|
||||
|
||||
# Manually inject a fake device list update. We need this update to include at
|
||||
# least one prev_id so that the user's device list will need to be retried.
|
||||
device_list_updater = self.hs.get_device_handler().device_list_updater
|
||||
assert isinstance(device_list_updater, DeviceListUpdater)
|
||||
self.get_success(
|
||||
device_list_updater.incoming_device_list_update(
|
||||
origin=remote_origin,
|
||||
edu_content={
|
||||
"deleted": False,
|
||||
"device_display_name": "Mobile",
|
||||
"device_id": "QBUAZIFURK",
|
||||
"prev_id": [5],
|
||||
"stream_id": 6,
|
||||
"user_id": remote_user_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Check that there was one resync attempt.
|
||||
self.assertEqual(self.resync_attempts, 1)
|
||||
|
||||
# Check that the resync attempt failed and caused the user's device list to be
|
||||
# marked as stale.
|
||||
need_resync = self.get_success(
|
||||
store.get_user_ids_requiring_device_list_resync()
|
||||
)
|
||||
self.assertIn(remote_user_id, need_resync)
|
||||
|
||||
# Check that waiting for 30 seconds caused Synapse to retry resyncing the device
|
||||
# list.
|
||||
self.reactor.advance(30)
|
||||
self.assertEqual(self.resync_attempts, 2)
|
||||
|
||||
def test_cross_signing_keys_retry(self) -> None:
|
||||
"""Tests that resyncing a device list correctly processes cross-signing keys from
|
||||
the remote server.
|
||||
"""
|
||||
remote_user_id = "@john:test_remote"
|
||||
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
|
||||
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
|
||||
|
||||
# Register mock device list retrieval on the federation client.
|
||||
federation_client = self.hs.get_federation_client()
|
||||
federation_client.query_user_devices = AsyncMock( # type: ignore[method-assign]
|
||||
return_value={
|
||||
"user_id": remote_user_id,
|
||||
"stream_id": 1,
|
||||
"devices": [],
|
||||
"master_key": {
|
||||
"user_id": remote_user_id,
|
||||
"usage": ["master"],
|
||||
"keys": {"ed25519:" + remote_master_key: remote_master_key},
|
||||
},
|
||||
"self_signing_key": {
|
||||
"user_id": remote_user_id,
|
||||
"usage": ["self_signing"],
|
||||
"keys": {
|
||||
"ed25519:" + remote_self_signing_key: remote_self_signing_key
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Resync the device list.
|
||||
device_handler = self.hs.get_device_handler()
|
||||
self.get_success(
|
||||
device_handler.device_list_updater.multi_user_device_resync(
|
||||
[remote_user_id]
|
||||
),
|
||||
)
|
||||
|
||||
# Retrieve the cross-signing keys for this user.
|
||||
keys = self.get_success(
|
||||
self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]),
|
||||
)
|
||||
self.assertIn(remote_user_id, keys)
|
||||
key = keys[remote_user_id]
|
||||
assert key is not None
|
||||
|
||||
# Check that the master key is the one returned by the mock.
|
||||
master_key = key["master"]
|
||||
self.assertEqual(len(master_key["keys"]), 1)
|
||||
self.assertTrue("ed25519:" + remote_master_key in master_key["keys"].keys())
|
||||
self.assertTrue(remote_master_key in master_key["keys"].values())
|
||||
|
||||
# Check that the self-signing key is the one returned by the mock.
|
||||
self_signing_key = key["self_signing"]
|
||||
self.assertEqual(len(self_signing_key["keys"]), 1)
|
||||
self.assertTrue(
|
||||
"ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(),
|
||||
)
|
||||
self.assertTrue(remote_self_signing_key in self_signing_key["keys"].values())
|
||||
|
||||
|
||||
class StripUnsignedFromEventsTestCase(unittest.TestCase):
|
||||
def test_strip_unauthorized_unsigned_values(self) -> None:
|
||||
event1 = {
|
||||
"sender": "@baduser:test.serv",
|
||||
"state_key": "@baduser:test.serv",
|
||||
"event_id": "$event1:test.serv",
|
||||
"depth": 1000,
|
||||
"origin_server_ts": 1,
|
||||
"type": "m.room.member",
|
||||
"origin": "test.servx",
|
||||
"content": {"membership": "join"},
|
||||
"auth_events": [],
|
||||
"unsigned": {"malicious garbage": "hackz", "more warez": "more hackz"},
|
||||
}
|
||||
filtered_event = event_from_pdu_json(event1, RoomVersions.V1)
|
||||
# Make sure unauthorized fields are stripped from unsigned
|
||||
self.assertNotIn("more warez", filtered_event.unsigned)
|
||||
|
||||
def test_strip_event_maintains_allowed_fields(self) -> None:
|
||||
event2 = {
|
||||
"sender": "@baduser:test.serv",
|
||||
"state_key": "@baduser:test.serv",
|
||||
"event_id": "$event2:test.serv",
|
||||
"depth": 1000,
|
||||
"origin_server_ts": 1,
|
||||
"type": "m.room.member",
|
||||
"origin": "test.servx",
|
||||
"auth_events": [],
|
||||
"content": {"membership": "join"},
|
||||
"unsigned": {
|
||||
"malicious garbage": "hackz",
|
||||
"more warez": "more hackz",
|
||||
"age": 14,
|
||||
"invite_room_state": [],
|
||||
},
|
||||
}
|
||||
|
||||
filtered_event2 = event_from_pdu_json(event2, RoomVersions.V1)
|
||||
self.assertIn("age", filtered_event2.unsigned)
|
||||
self.assertEqual(14, filtered_event2.unsigned["age"])
|
||||
self.assertNotIn("more warez", filtered_event2.unsigned)
|
||||
# Invite_room_state is allowed in events of type m.room.member
|
||||
self.assertIn("invite_room_state", filtered_event2.unsigned)
|
||||
self.assertEqual([], filtered_event2.unsigned["invite_room_state"])
|
||||
|
||||
def test_strip_event_removes_fields_based_on_event_type(self) -> None:
|
||||
event3 = {
|
||||
"sender": "@baduser:test.serv",
|
||||
"state_key": "@baduser:test.serv",
|
||||
"event_id": "$event3:test.serv",
|
||||
"depth": 1000,
|
||||
"origin_server_ts": 1,
|
||||
"type": "m.room.power_levels",
|
||||
"origin": "test.servx",
|
||||
"content": {},
|
||||
"auth_events": [],
|
||||
"unsigned": {
|
||||
"malicious garbage": "hackz",
|
||||
"more warez": "more hackz",
|
||||
"age": 14,
|
||||
"invite_room_state": [],
|
||||
},
|
||||
}
|
||||
filtered_event3 = event_from_pdu_json(event3, RoomVersions.V1)
|
||||
self.assertIn("age", filtered_event3.unsigned)
|
||||
# Invite_room_state field is only permitted in event type m.room.member
|
||||
self.assertNotIn("invite_room_state", filtered_event3.unsigned)
|
||||
self.assertNotIn("more warez", filtered_event3.unsigned)
|
||||
@@ -20,7 +20,11 @@
|
||||
#
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.util.stringutils import assert_valid_client_secret, base62_encode
|
||||
from synapse.util.stringutils import (
|
||||
assert_valid_client_secret,
|
||||
base62_encode,
|
||||
is_namedspaced_grammar,
|
||||
)
|
||||
|
||||
from .. import unittest
|
||||
|
||||
@@ -58,3 +62,25 @@ class StringUtilsTestCase(unittest.TestCase):
|
||||
self.assertEqual("10", base62_encode(62))
|
||||
self.assertEqual("1c", base62_encode(100))
|
||||
self.assertEqual("001c", base62_encode(100, minwidth=4))
|
||||
|
||||
def test_namespaced_identifier(self) -> None:
|
||||
self.assertTrue(is_namedspaced_grammar("test"))
|
||||
self.assertTrue(is_namedspaced_grammar("m.test"))
|
||||
self.assertTrue(is_namedspaced_grammar("org.matrix.test"))
|
||||
self.assertTrue(is_namedspaced_grammar("org.matrix.msc1234"))
|
||||
self.assertTrue(is_namedspaced_grammar("test"))
|
||||
self.assertTrue(is_namedspaced_grammar("t-e_s.t"))
|
||||
|
||||
# Must start with letter.
|
||||
self.assertFalse(is_namedspaced_grammar("1test"))
|
||||
self.assertFalse(is_namedspaced_grammar("-test"))
|
||||
self.assertFalse(is_namedspaced_grammar("_test"))
|
||||
self.assertFalse(is_namedspaced_grammar(".test"))
|
||||
|
||||
# Must contain only a-z, 0-9, -, _, ..
|
||||
self.assertFalse(is_namedspaced_grammar("test/"))
|
||||
self.assertFalse(is_namedspaced_grammar('test"'))
|
||||
self.assertFalse(is_namedspaced_grammar("testö"))
|
||||
|
||||
# Must be < 255 characters.
|
||||
self.assertFalse(is_namedspaced_grammar("t" * 256))
|
||||
|
||||
@@ -200,6 +200,7 @@ def default_config(
|
||||
"per_user": {"per_second": 10000, "burst_count": 10000},
|
||||
},
|
||||
"rc_3pid_validation": {"per_second": 10000, "burst_count": 10000},
|
||||
"rc_presence": {"per_user": {"per_second": 10000, "burst_count": 10000}},
|
||||
"saml2_enabled": False,
|
||||
"public_baseurl": None,
|
||||
"default_identity_server": None,
|
||||
@@ -399,11 +400,24 @@ class TestTimeout(Exception):
|
||||
|
||||
|
||||
class test_timeout:
|
||||
"""
|
||||
FIXME: This implementation is not robust against other code tight-looping and
|
||||
preventing the signals propagating and timing out the test. You may need to add
|
||||
`time.sleep(0.1)` to your code in order to allow this timeout to work correctly.
|
||||
|
||||
```py
|
||||
with test_timeout(3):
|
||||
while True:
|
||||
my_checking_func()
|
||||
time.sleep(0.1)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, seconds: int, error_message: Optional[str] = None) -> None:
|
||||
if error_message is None:
|
||||
error_message = "test timed out after {}s.".format(seconds)
|
||||
self.error_message = f"Test timed out after {seconds}s"
|
||||
if error_message is not None:
|
||||
self.error_message += f": {error_message}"
|
||||
self.seconds = seconds
|
||||
self.error_message = error_message
|
||||
|
||||
def handle_timeout(self, signum: int, frame: Optional[FrameType]) -> None:
|
||||
raise TestTimeout(self.error_message)
|
||||
|
||||
Reference in New Issue
Block a user