Compare commits

...

10 Commits

Author SHA1 Message Date
Andrew Morgan
3bd26733d7 Add tests
We add a series of tests that check whether device list sending works
across a variety of possible configurations.
2022-03-10 15:50:58 +00:00
Andrew Morgan
90fa2026ba fix tests for device lists 2022-03-10 15:50:58 +00:00
Andrew Morgan
55ac419b63 Add device lists to AS txns, thread thru the AS scheduler methods
Here we implement code that adds support for device list changes all
the way from our enqueue_for_appservice method down to where AS
transactions are actually built and sent out.
2022-03-10 15:50:58 +00:00
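A rough, standalone sketch of where the data ends up once it has been threaded through: when a transaction is built, a non-empty summary is serialised under the MSC3202 unstable prefix (matching the api.py change further down this diff). The field names come from the diff; the helper function itself is hypothetical and simplified.

# Standalone sketch; add_device_lists_to_txn_body is a hypothetical helper,
# but the "org.matrix.msc3202.device_lists" payload shape matches the diff.
from typing import Any, Dict, Set


def add_device_lists_to_txn_body(
    body: Dict[str, Any], changed: Set[str], left: Set[str]
) -> None:
    # Only attach the field when there is something to report.
    if changed or left:
        body["org.matrix.msc3202.device_lists"] = {
            "changed": list(changed),
            "left": list(left),
        }


txn_body: Dict[str, Any] = {"events": []}
add_device_lists_to_txn_body(txn_body, changed={"@alice:example.com"}, left=set())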
Andrew Morgan
047db4da1c Use get_users_whose_devices_changed to pull device list changes for given AS
When a new device list change occurs, we now:

1. For each appservice, check the device list stream key that it was last
   processed up to.
2. Get any users whose device lists changed between that stream key and the
   stream key of the triggering update.
3. Filter those users down to the ones that are actually relevant to this
   application service.
4. Pass those changes to enqueue_for_appservice and save the device list
   stream key that we've just processed up to for later reference (a sketch
   of this flow follows below).
2022-03-10 15:50:58 +00:00
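A minimal, self-contained sketch of the per-appservice flow described above. The helper callables (get_last_processed_stream_id, get_changed_users, is_interested_in, enqueue, save_stream_id) are illustrative stand-ins rather than the real Synapse store and scheduler methods; the real logic lives in _get_device_list_summary further down this diff.

# Illustrative sketch (not the real Synapse API): the per-appservice flow,
# with stand-in callables in place of the store/scheduler methods.
from typing import Callable, Set


def process_device_list_change(
    new_stream_id: int,
    get_last_processed_stream_id: Callable[[], int],
    get_changed_users: Callable[[int, int], Set[str]],
    is_interested_in: Callable[[str], bool],
    enqueue: Callable[[Set[str]], None],
    save_stream_id: Callable[[int], None],
) -> None:
    # 1. Find the stream key this appservice was last processed up to.
    from_stream_id = get_last_processed_stream_id()
    # 2. Pull users whose device lists changed in (from_stream_id, new_stream_id].
    changed_users = get_changed_users(from_stream_id, new_stream_id)
    # 3. Keep only the users this appservice is actually interested in.
    relevant_users = {u for u in changed_users if is_interested_in(u)}
    # 4. Hand the changes to the scheduler and record the new position.
    if relevant_users:
        enqueue(relevant_users)
    save_stream_id(new_stream_id)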
Andrew Morgan
88c4e7369d Switch DeviceLists to containing Sets, which allows item deletes
In the next commit, we'll be merging one DeviceLists into another. That
requires the ability to remove items by value, which Collection does not
provide but a mutable structure such as Set does. Set also removes
duplicate user IDs.
2022-03-10 15:50:58 +00:00
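A minimal standalone sketch of the merge that motivates using sets: merging one summary into another must be able to remove a user from the opposite set. It mirrors the consolidation logic added to the scheduler later in this branch, but uses a plain dataclass rather than the attrs-based DeviceLists type.

# Simplified stand-in for synapse.types.DeviceLists; not the real class.
from dataclasses import dataclass, field
from typing import Set


@dataclass
class DeviceListsSketch:
    changed: Set[str] = field(default_factory=set)
    left: Set[str] = field(default_factory=set)


def merge(into: DeviceListsSketch, incoming: DeviceListsSketch) -> None:
    for user_id in incoming.changed:
        # A user whose devices changed must be tracked again.
        into.left.discard(user_id)
        into.changed.add(user_id)
    for user_id in incoming.left:
        # A user we no longer track should not also appear as "changed".
        into.changed.discard(user_id)
        into.left.add(user_id)


summary = DeviceListsSketch(changed={"@a:example.com"})
merge(summary, DeviceListsSketch(left={"@a:example.com"}))
assert summary.left == {"@a:example.com"} and not summary.changed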
Andrew Morgan
a77f35144f Move DeviceLists type to synapse.types
So that we can use it elsewhere.
2022-03-10 15:50:58 +00:00
Andrew Morgan
1671f8772d Add migration delta to track device_list stream id per appservice 2022-03-10 15:50:58 +00:00
Andrew Morgan
b4aad3604a Add to_key arg, user_ids optional for get_users_whose_devices_changed
to_key provides an upper bound, preventing overlapping ranges when pulling
out device list updates across successive calls.

user_ids needs to be optional as we won't have a list of user_ids to
filter with when calling this function from a triggered device_list
change.
2022-03-10 15:50:58 +00:00
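A standalone illustration of the (from_key, to_key] windowing, using an assumed in-memory stand-in for the device_lists_stream table: because the lower bound is exclusive and the upper bound inclusive, consecutive pulls never return the same row twice.

# Illustrative only; ROWS stands in for rows of the device_lists_stream table.
from typing import List, Optional, Set, Tuple

ROWS: List[Tuple[int, str]] = [  # (stream_id, user_id)
    (1, "@a:example.com"),
    (2, "@b:example.com"),
    (3, "@a:example.com"),
]


def users_changed(from_key: int, to_key: Optional[int] = None) -> Set[str]:
    return {
        user_id
        for stream_id, user_id in ROWS
        if stream_id > from_key and (to_key is None or stream_id <= to_key)
    }


first = users_changed(0, to_key=2)   # {"@a:example.com", "@b:example.com"}
second = users_changed(2, to_key=3)  # {"@a:example.com"}; no overlap with first
assert first == {"@a:example.com", "@b:example.com"}
assert second == {"@a:example.com"}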
Andrew Morgan
51be04b918 Guard processing device list updates with experimental option 2022-03-10 15:50:58 +00:00
Andrew Morgan
4b6711803d Set min application service stream_id to 1
Factored out into #12193.
2022-03-09 17:27:52 +00:00
14 changed files with 469 additions and 83 deletions

View File

@@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,7 +23,7 @@ from netaddr import IPSet
from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import GroupID, JsonDict, UserID, get_domain_from_id
from synapse.types import DeviceLists, GroupID, JsonDict, UserID, get_domain_from_id
from synapse.util.caches.descriptors import _CacheContext, cached
if TYPE_CHECKING:
@@ -400,6 +401,7 @@ class AppServiceTransaction:
to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceLists,
):
self.service = service
self.id = id
@@ -408,6 +410,7 @@ class AppServiceTransaction:
self.to_device_messages = to_device_messages
self.one_time_key_counts = one_time_key_counts
self.unused_fallback_keys = unused_fallback_keys
self.device_list_summary = device_list_summary
async def send(self, as_api: "ApplicationServiceApi") -> bool:
"""Sends this transaction using the provided AS API interface.
@@ -424,6 +427,7 @@ class AppServiceTransaction:
to_device_messages=self.to_device_messages,
one_time_key_counts=self.one_time_key_counts,
unused_fallback_keys=self.unused_fallback_keys,
device_list_summary=self.device_list_summary,
txn_id=self.id,
)

View File

@@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -27,7 +28,7 @@ from synapse.appservice import (
from synapse.events import EventBase
from synapse.events.utils import SerializeEventConfig, serialize_event
from synapse.http.client import SimpleHttpClient
from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.types import DeviceLists, JsonDict, ThirdPartyInstanceID
from synapse.util.caches.response_cache import ResponseCache
if TYPE_CHECKING:
@@ -225,6 +226,7 @@ class ApplicationServiceApi(SimpleHttpClient):
to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceLists,
txn_id: Optional[int] = None,
) -> bool:
"""
@@ -268,6 +270,7 @@ class ApplicationServiceApi(SimpleHttpClient):
}
)
# TODO: Update to stable prefixes once MSC3202 completes FCP merge
if service.msc3202_transaction_extensions:
if one_time_key_counts:
body[
@@ -277,6 +280,11 @@ class ApplicationServiceApi(SimpleHttpClient):
body[
"org.matrix.msc3202.device_unused_fallback_keys"
] = unused_fallback_keys
if device_list_summary:
body["org.matrix.msc3202.device_lists"] = {
"changed": list(device_list_summary.changed),
"left": list(device_list_summary.left),
}
try:
await self.put_json(

View File

@@ -72,7 +72,7 @@ from synapse.events import EventBase
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases.main import DataStore
from synapse.types import JsonDict
from synapse.types import DeviceLists, JsonDict
from synapse.util import Clock
if TYPE_CHECKING:
@@ -122,6 +122,7 @@ class ApplicationServiceScheduler:
events: Optional[Collection[EventBase]] = None,
ephemeral: Optional[Collection[JsonDict]] = None,
to_device_messages: Optional[Collection[JsonDict]] = None,
device_list_summary: Optional[DeviceLists] = None,
) -> None:
"""
Enqueue some data to be sent off to an application service.
@@ -133,10 +134,18 @@ class ApplicationServiceScheduler:
to_device_messages: The to-device messages to send. These differ from normal
to-device messages sent to clients, as they have 'to_device_id' and
'to_user_id' fields.
device_list_summary: A summary of users whose device lists the application
service either needs to refresh, or no longer needs to track.
"""
# We purposefully allow this method to run with empty events/ephemeral
# collections, so that callers do not need to check iterable size themselves.
if not events and not ephemeral and not to_device_messages:
if (
not events
and not ephemeral
and not to_device_messages
and not device_list_summary
):
return
if events:
@@ -147,6 +156,10 @@ class ApplicationServiceScheduler:
self.queuer.queued_to_device_messages.setdefault(appservice.id, []).extend(
to_device_messages
)
if device_list_summary:
self.queuer.queued_device_list_summaries.setdefault(
appservice.id, []
).append(device_list_summary)
# Kick off a new application service transaction
self.queuer.start_background_request(appservice)
@@ -169,6 +182,8 @@ class _ServiceQueuer:
self.queued_ephemeral: Dict[str, List[JsonDict]] = {}
# dict of {service_id: [to_device_message_json]}
self.queued_to_device_messages: Dict[str, List[JsonDict]] = {}
# dict of {service_id: [device_list_summary]}
self.queued_device_list_summaries: Dict[str, List[DeviceLists]] = {}
# the appservices which currently have a transaction in flight
self.requests_in_flight: Set[str] = set()
@@ -212,7 +227,40 @@ class _ServiceQueuer:
]
del all_to_device_messages[:MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION]
if not events and not ephemeral and not to_device_messages_to_send:
# Consolidate any pending device list summaries into a single, up-to-date
# summary.
# Note: this code assumes that in a single DeviceLists, a user will
# never be in both "changed" and "left" sets.
device_list_summary = DeviceLists()
while self.queued_device_list_summaries.get(service.id, []):
# Pop a summary off the front of the queue
summary = self.queued_device_list_summaries[service.id].pop(0)
# For every user in the incoming "changed" set:
# * Remove them from the existing "left" set if necessary
# (as we need to start tracking them again)
# * Add them to the existing "changed" set if necessary.
for user_id in summary.changed:
if user_id in device_list_summary.left:
device_list_summary.left.remove(user_id)
device_list_summary.changed.add(user_id)
# For every user in the incoming "left" set:
# * Remove them from the existing "changed" set if necessary
# (we no longer need to track them)
# * Add them to the existing "left" set if necessary.
for user_id in summary.left:
if user_id in device_list_summary.changed:
device_list_summary.changed.remove(user_id)
device_list_summary.left.add(user_id)
if (
not events
and not ephemeral
and not to_device_messages_to_send
# Note that DeviceLists implements __bool__
and not device_list_summary
):
return
one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None
@@ -240,6 +288,7 @@ class _ServiceQueuer:
to_device_messages_to_send,
one_time_key_counts,
unused_fallback_keys,
device_list_summary,
)
except Exception:
logger.exception("AS request failed")
@@ -322,6 +371,7 @@ class _TransactionController:
to_device_messages: Optional[List[JsonDict]] = None,
one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None,
unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None,
device_list_summary: Optional[DeviceLists] = None,
) -> None:
"""
Create a transaction with the given data and send to the provided
@@ -336,6 +386,7 @@ class _TransactionController:
appservice devices in the transaction.
unused_fallback_keys: Lists of unused fallback keys for relevant
appservice devices in the transaction.
device_list_summary: The device list summary to include in the transaction.
"""
try:
txn = await self.store.create_appservice_txn(
@@ -345,6 +396,7 @@ class _TransactionController:
to_device_messages=to_device_messages or [],
one_time_key_counts=one_time_key_counts or {},
unused_fallback_keys=unused_fallback_keys or {},
device_list_summary=device_list_summary or DeviceLists(),
)
service_is_up = await self._is_service_up(service)
if service_is_up:

View File

@@ -170,6 +170,7 @@ def _load_appservice(
# When enabled, appservice transactions contain the following information:
# - device One-Time Key counts
# - device unused fallback key usage states
# - device list changes
msc3202_transaction_extensions = as_info.get("org.matrix.msc3202", False)
if not isinstance(msc3202_transaction_extensions, bool):
raise ValueError(

View File

@@ -59,8 +59,9 @@ class ExperimentalConfig(Config):
"msc3202_device_masquerading", False
)
# Portion of MSC3202 related to transaction extensions:
# sending one-time key counts and fallback key usage to application services.
# The portion of MSC3202 related to transaction extensions:
# sending device list changes, one-time key counts and fallback key
# usage to application services.
self.msc3202_transaction_extensions: bool = experimental.get(
"msc3202_transaction_extensions", False
)

View File

@@ -33,7 +33,7 @@ from synapse.metrics.background_process_metrics import (
wrap_as_background_process,
)
from synapse.storage.databases.main.directory import RoomAliasMapping
from synapse.types import JsonDict, RoomAlias, RoomStreamToken, UserID
from synapse.types import DeviceLists, JsonDict, RoomAlias, RoomStreamToken, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.metrics import Measure
@@ -58,6 +58,9 @@ class ApplicationServicesHandler:
self._msc2409_to_device_messages_enabled = (
hs.config.experimental.msc2409_to_device_messages_enabled
)
self._msc3202_transaction_extensions_enabled = (
hs.config.experimental.msc3202_transaction_extensions
)
self.current_max = 0
self.is_processing = False
@@ -204,9 +207,9 @@ class ApplicationServicesHandler:
Args:
stream_key: The stream the event came from.
`stream_key` can be "typing_key", "receipt_key", "presence_key" or
"to_device_key". Any other value for `stream_key` will cause this function
to return early.
`stream_key` can be "typing_key", "receipt_key", "presence_key",
"to_device_key" or "device_list_key". Any other value for `stream_key`
will cause this function to return early.
Ephemeral events will only be pushed to appservices that have opted into
receiving them by setting `push_ephemeral` to true in their registration
@@ -230,6 +233,7 @@ class ApplicationServicesHandler:
"receipt_key",
"presence_key",
"to_device_key",
"device_list_key",
):
return
@@ -253,15 +257,37 @@ class ApplicationServicesHandler:
):
return
# Ignore device lists if the feature flag is not enabled
if (
stream_key == "device_list_key"
and not self._msc3202_transaction_extensions_enabled
):
return
# Check whether there are any appservices which have registered to receive
# ephemeral events.
#
# Note that whether these events are actually relevant to these appservices
# is decided later on.
services = self.store.get_app_services()
services = [
service
for service in self.store.get_app_services()
if service.supports_ephemeral
for service in services
# Different stream keys require different support booleans
if (
stream_key
in (
"typing_key",
"receipt_key",
"presence_key",
"to_device_key",
)
and service.supports_ephemeral
)
or (
stream_key == "device_list_key"
and service.msc3202_transaction_extensions
)
]
if not services:
# Bail out early if none of the target appservices have explicitly registered
@@ -336,6 +362,20 @@ class ApplicationServicesHandler:
service, "to_device", new_token
)
elif stream_key == "device_list_key":
device_list_summary = await self._get_device_list_summary(
service, new_token
)
if device_list_summary:
self.scheduler.enqueue_for_appservice(
service, device_list_summary=device_list_summary
)
# Persist the latest handled stream token for this appservice
await self.store.set_appservice_stream_type_pos(
service, "device_list", new_token
)
async def _handle_typing(
self, service: ApplicationService, new_token: int
) -> List[JsonDict]:
@@ -542,6 +582,98 @@ class ApplicationServicesHandler:
return message_payload
async def _get_device_list_summary(
self,
appservice: ApplicationService,
new_key: int,
) -> DeviceLists:
"""
Retrieve a list of users who have changed their device lists.
Args:
appservice: The application service to retrieve device list changes for.
new_key: The stream key of the device list change that triggered this method call.
Returns:
A set of device list updates, comprising users that the appservice needs to:
* resync the device list of, and
* stop tracking the device list of.
"""
# Fetch the last successfully processed device list update stream ID
# for this appservice.
from_key = await self.store.get_type_stream_id_for_appservice(
appservice, "device_list"
)
# Fetch the users who have modified their device list since then.
users_with_changed_device_lists = (
await self.store.get_users_whose_devices_changed(
from_key, user_ids=None, to_key=new_key
)
)
# Filter out any users the application service is not interested in
#
# For each user who changed their device list, we want to check whether this
# appservice would be interested in the change.
filtered_users_with_changed_device_lists = {
user_id
for user_id in users_with_changed_device_lists
if await self._is_appservice_interested_in_device_lists_of_user(
appservice, user_id
)
}
# Create a summary of "changed" and "left" users.
# TODO: Calculate "left" users.
device_list_summary = DeviceLists(
changed=filtered_users_with_changed_device_lists
)
return device_list_summary
async def _is_appservice_interested_in_device_lists_of_user(
self,
appservice: ApplicationService,
user_id: str,
) -> bool:
"""
Returns whether a given application service is interested in the device list
updates of a given user.
The application service is interested in the user's device list updates if any
of the following are true:
* The user is the appservice's sender localpart user.
* The user is in the appservice's user namespace.
* At least one member of one room that the user is a part of is in the
appservice's user namespace.
* The appservice is explicitly (via room ID or alias) interested in at
least one room that the user is in.
Args:
appservice: The application service to gauge interest of.
user_id: The ID of the user whose device list interest is in question.
Returns:
True if the application service is interested in the user's device lists, False
otherwise.
"""
# This method checks against both the sender localpart user as well as if the
# user is in the appservice's user namespace.
if appservice.is_interested_in_user(user_id):
return True
# FIXME: This is quite an expensive check. This method is called per device
# list change.
room_ids = await self.store.get_rooms_for_user(user_id)
for room_id in room_ids:
# This method covers checking room members for appservice interest as well as
# room ID and alias checks.
if await appservice.is_interested_in_room(room_id, self.store):
return True
return False
async def query_user_exists(self, user_id: str) -> bool:
"""Check if any application service knows this user_id exists.

View File

@@ -13,17 +13,7 @@
# limitations under the License.
import itertools
import logging
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
FrozenSet,
List,
Optional,
Set,
Tuple,
)
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
import attr
from prometheus_client import Counter
@@ -41,6 +31,7 @@ from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
from synapse.types import (
DeviceLists,
JsonDict,
MutableStateMap,
Requester,
@@ -184,21 +175,6 @@ class GroupsSyncResult:
return bool(self.join or self.invite or self.leave)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceLists:
"""
Attributes:
changed: List of user_ids whose devices may have changed
left: List of user_ids whose devices we no longer track
"""
changed: Collection[str]
left: Collection[str]
def __bool__(self) -> bool:
return bool(self.changed or self.left)
@attr.s(slots=True, auto_attribs=True)
class _RoomChanges:
"""The set of room entries to include in the sync, plus the set of joined
@@ -1380,7 +1356,7 @@ class SyncHandler:
return DeviceLists(changed=users_that_have_changed, left=newly_left_users)
else:
return DeviceLists(changed=[], left=[])
return DeviceLists()
async def _generate_sync_entry_for_to_device(
self, sync_result_builder: "SyncResultBuilder"

View File

@@ -29,7 +29,7 @@ from synapse.storage._base import db_to_json
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.types import JsonDict
from synapse.types import DeviceLists, JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import _CacheContext, cached
@@ -217,6 +217,7 @@ class ApplicationServiceTransactionWorkerStore(
to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceLists,
) -> AppServiceTransaction:
"""Atomically creates a new transaction for this application service
with the given list of events. Ephemeral events are NOT persisted to the
@@ -231,6 +232,7 @@ class ApplicationServiceTransactionWorkerStore(
appservice devices in the transaction.
unused_fallback_keys: Lists of unused fallback keys for relevant
appservice devices in the transaction.
device_list_summary: The device list summary to include in the transaction.
Returns:
A new transaction.
@@ -268,6 +270,7 @@ class ApplicationServiceTransactionWorkerStore(
to_device_messages=to_device_messages,
one_time_key_counts=one_time_key_counts,
unused_fallback_keys=unused_fallback_keys,
device_list_summary=device_list_summary,
)
return await self.db_pool.runInteraction(
@@ -359,8 +362,8 @@ class ApplicationServiceTransactionWorkerStore(
events = await self.get_events_as_list(event_ids)
# TODO: to-device messages, one-time key counts and unused fallback keys
# are not yet populated for catch-up transactions.
# TODO: to-device messages, one-time key counts, device list summaries and unused
# fallback keys are not yet populated for catch-up transactions.
# We likely want to populate those for reliability.
return AppServiceTransaction(
service=service,
@@ -370,6 +373,7 @@ class ApplicationServiceTransactionWorkerStore(
to_device_messages=[],
one_time_key_counts={},
unused_fallback_keys={},
device_list_summary=DeviceLists(),
)
def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
@@ -430,7 +434,7 @@ class ApplicationServiceTransactionWorkerStore(
async def get_type_stream_id_for_appservice(
self, service: ApplicationService, type: str
) -> int:
if type not in ("read_receipt", "presence", "to_device"):
if type not in ("read_receipt", "presence", "to_device", "device_list"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (type,)
@@ -446,7 +450,7 @@ class ApplicationServiceTransactionWorkerStore(
)
last_stream_id = txn.fetchone()
if last_stream_id is None or last_stream_id[0] is None: # no row exists
return 0
return 1
else:
return int(last_stream_id[0])
@@ -457,7 +461,7 @@ class ApplicationServiceTransactionWorkerStore(
async def set_appservice_stream_type_pos(
self, service: ApplicationService, stream_type: str, pos: Optional[int]
) -> None:
if stream_type not in ("read_receipt", "presence", "to_device"):
if stream_type not in ("read_receipt", "presence", "to_device", "device_list"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (stream_type,)

View File

@@ -681,42 +681,64 @@ class DeviceWorkerStore(SQLBaseStore):
return self._device_list_stream_cache.get_all_entities_changed(from_key)
async def get_users_whose_devices_changed(
self, from_key: int, user_ids: Iterable[str]
self,
from_key: int,
user_ids: Optional[Iterable[str]] = None,
to_key: Optional[int] = None,
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
are in the given list of user_ids.
Args:
from_key: The device lists stream token
user_ids: The user IDs to query for devices.
from_key: The minimum device lists stream token to query device list changes for,
exclusive.
user_ids: If provided, only check if these users have changed their device lists.
Otherwise changes from all users are returned.
to_key: The maximum device lists stream token to query device list changes for,
inclusive.
Returns:
The set of user_ids whose devices have changed since `from_key`
The set of user_ids whose devices have changed since `from_key` (exclusive)
until `to_key` (inclusive).
"""
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
to_check = self._device_list_stream_cache.get_entities_changed(
user_ids, from_key
)
if user_ids is None:
# Get set of all users that have had device list changes since 'from_key'
user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
from_key
)
else:
# The same as above, but filter results to only those users in 'user_ids'
user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
user_ids, from_key
)
if not to_check:
if not user_ids_to_check:
return set()
def _get_users_whose_devices_changed_txn(txn):
changes = set()
sql = """
stream_id_where_clause = "stream_id > ?"
sql_args = [from_key]
if to_key:
stream_id_where_clause += " AND stream_id <= ?"
sql_args += [to_key]
sql = f"""
SELECT DISTINCT user_id FROM device_lists_stream
WHERE stream_id > ?
WHERE {stream_id_where_clause}
AND
"""
for chunk in batch_iter(to_check, 100):
# Query device changes with a batch of users at a time
for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
)
txn.execute(sql + clause, (from_key,) + tuple(args))
# Combine the stream bounds with this chunk's user ID parameters without
# mutating sql_args across iterations.
txn.execute(sql + clause, sql_args + args)
changes.update(user_id for user_id, in txn)
return changes

View File

@@ -0,0 +1,18 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Add a column to track the device list changes stream id that this application
-- service has been caught up to.
ALTER TABLE application_services_state ADD COLUMN device_list_stream_id BIGINT;

View File

@@ -25,6 +25,7 @@ from typing import (
Match,
MutableMapping,
Optional,
Set,
Tuple,
Type,
TypeVar,
@@ -743,6 +744,26 @@ class ReadReceipt:
data: JsonDict
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceLists:
"""
Attributes:
changed: user_ids whose devices may have changed
left: user_ids whose devices we no longer track
"""
# We need to use a factory here, otherwise `set` is not evaluated at
# object instantiation, but instead when the class is defined.
# The latter happens only once, which would give every DeviceLists
# instance the same shared sets.
# See also: the usual warning about mutable default arguments.
changed: Set[str] = attr.ib(factory=set)
left: Set[str] = attr.ib(factory=set)
def __bool__(self) -> bool:
return bool(self.changed or self.left)
def get_verify_key_from_cross_signing_key(key_info):
"""Get the key ID and signedjson verify key from a cross-signing key dict

View File

@@ -24,6 +24,7 @@ from synapse.appservice.scheduler import (
)
from synapse.logging.context import make_deferred_yieldable
from synapse.server import HomeServer
from synapse.types import DeviceLists
from synapse.util import Clock
from tests import unittest
@@ -70,6 +71,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
to_device_messages=[], # txn made and saved
one_time_key_counts={},
unused_fallback_keys={},
device_list_summary=DeviceLists(),
)
self.assertEqual(0, len(self.txnctrl.recoverers)) # no recoverer made
txn.complete.assert_called_once_with(self.store) # txn completed
@@ -96,6 +98,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
to_device_messages=[], # txn made and saved
one_time_key_counts={},
unused_fallback_keys={},
device_list_summary=DeviceLists(),
)
self.assertEqual(0, txn.send.call_count) # txn not sent though
self.assertEqual(0, txn.complete.call_count) # or completed
@@ -124,6 +127,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
to_device_messages=[],
one_time_key_counts={},
unused_fallback_keys={},
device_list_summary=DeviceLists(),
)
self.assertEqual(1, self.recoverer_fn.call_count) # recoverer made
self.assertEqual(1, self.recoverer.recover.call_count) # and invoked
@@ -225,7 +229,9 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
service = Mock(id=4)
event = Mock()
self.scheduler.enqueue_for_appservice(service, events=[event])
self.txn_ctrl.send.assert_called_once_with(service, [event], [], [], None, None)
self.txn_ctrl.send.assert_called_once_with(
service, [event], [], [], None, None, DeviceLists()
)
def test_send_single_event_with_queue(self):
d = defer.Deferred()
@@ -240,12 +246,14 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
# (call enqueue_for_appservice multiple times deliberately)
self.scheduler.enqueue_for_appservice(service, events=[event2])
self.scheduler.enqueue_for_appservice(service, events=[event3])
self.txn_ctrl.send.assert_called_with(service, [event], [], [], None, None)
self.txn_ctrl.send.assert_called_with(
service, [event], [], [], None, None, DeviceLists()
)
self.assertEqual(1, self.txn_ctrl.send.call_count)
# Resolve the send event: expect the queued events to be sent
d.callback(service)
self.txn_ctrl.send.assert_called_with(
service, [event2, event3], [], [], None, None
service, [event2, event3], [], [], None, None, DeviceLists()
)
self.assertEqual(2, self.txn_ctrl.send.call_count)
@@ -272,15 +280,21 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
# send events for different ASes and make sure they are sent
self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event])
self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2])
self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [], None, None)
self.txn_ctrl.send.assert_called_with(
srv1, [srv_1_event], [], [], None, None, DeviceLists()
)
self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event])
self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2])
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [], None, None)
self.txn_ctrl.send.assert_called_with(
srv2, [srv_2_event], [], [], None, None, DeviceLists()
)
# make sure callbacks for a service only send queued events for THAT
# service
srv_2_defer.callback(srv2)
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [], None, None)
self.txn_ctrl.send.assert_called_with(
srv2, [srv_2_event2], [], [], None, None, DeviceLists()
)
self.assertEqual(3, self.txn_ctrl.send.call_count)
def test_send_large_txns(self):
@@ -300,17 +314,17 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
# Expect the first event to be sent immediately.
self.txn_ctrl.send.assert_called_with(
service, [event_list[0]], [], [], None, None
service, [event_list[0]], [], [], None, None, DeviceLists()
)
srv_1_defer.callback(service)
# Then send the next 100 events
self.txn_ctrl.send.assert_called_with(
service, event_list[1:101], [], [], None, None
service, event_list[1:101], [], [], None, None, DeviceLists()
)
srv_2_defer.callback(service)
# Then the final 99 events
self.txn_ctrl.send.assert_called_with(
service, event_list[101:], [], [], None, None
service, event_list[101:], [], [], None, None, DeviceLists()
)
self.assertEqual(3, self.txn_ctrl.send.call_count)
@@ -320,7 +334,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
event_list = [Mock(name="event")]
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with(
service, [], event_list, [], None, None
service, [], event_list, [], None, None, DeviceLists()
)
def test_send_multiple_ephemeral_no_queue(self):
@@ -329,7 +343,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")]
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with(
service, [], event_list, [], None, None
service, [], event_list, [], None, None, DeviceLists()
)
def test_send_single_ephemeral_with_queue(self):
@@ -345,13 +359,15 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
# Send more events: expect send() to NOT be called multiple times.
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2)
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3)
self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [], None, None)
self.txn_ctrl.send.assert_called_with(
service, [], event_list_1, [], None, None, DeviceLists()
)
self.assertEqual(1, self.txn_ctrl.send.call_count)
# Resolve txn_ctrl.send
d.callback(service)
# Expect the queued events to be sent
self.txn_ctrl.send.assert_called_with(
service, [], event_list_2 + event_list_3, [], None, None
service, [], event_list_2 + event_list_3, [], None, None, DeviceLists()
)
self.assertEqual(2, self.txn_ctrl.send.call_count)
@@ -365,8 +381,10 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
event_list = first_chunk + second_chunk
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with(
service, [], first_chunk, [], None, None
service, [], first_chunk, [], None, None, DeviceLists()
)
d.callback(service)
self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [], None, None)
self.txn_ctrl.send.assert_called_with(
service, [], second_chunk, [], None, None, DeviceLists()
)
self.assertEqual(2, self.txn_ctrl.send.call_count)

View File

@@ -15,6 +15,8 @@
from typing import Dict, Iterable, List, Optional
from unittest.mock import Mock
from parameterized import parameterized
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
@@ -471,6 +473,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
to_device_messages,
_otks,
_fbks,
_device_list_summary,
) = self.send_mock.call_args[0]
# Assert that this was the same to-device message that local_user sent
@@ -583,7 +586,15 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
service_id_to_message_count: Dict[str, int] = {}
for call in self.send_mock.call_args_list:
service, _events, _ephemeral, to_device_messages, _otks, _fbks = call[0]
(
service,
_events,
_ephemeral,
to_device_messages,
_otks,
_fbks,
_device_list_summary,
) = call[0]
# Check that this was made to an interested service
self.assertIn(service, interested_appservices)
@@ -627,6 +638,115 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
return appservice
class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase):
"""
Tests that the ApplicationServicesHandler sends device list updates to application
services correctly.
"""
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
room.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Allow us to modify cached feature flags mid-test
self.as_handler = hs.get_application_service_handler()
# Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that
# will be sent over the wire
self.put_json = simple_async_mock()
hs.get_application_service_api().put_json = self.put_json # type: ignore[assignment]
# Mock out application services, and allow defining our own in tests
self._services: List[ApplicationService] = []
self.hs.get_datastores().main.get_app_services = Mock(
return_value=self._services
)
# Test across a variety of configuration values
@parameterized.expand(
[
(True, True, True),
(True, False, False),
(False, True, False),
(False, False, False),
]
)
@unittest.override_config({"experimental_features": {"": False}})
def test_application_service_receives_device_list_updates(
self,
experimental_feature_enabled: bool,
as_supports_txn_extensions: bool,
as_should_receive_device_list_updates: bool,
):
"""
Tests that an application service receives notice of changed device
lists for a user, when a user changes their device lists.
Arguments above are populated by parameterized.
Args:
as_should_receive_device_list_updates: Whether we expect the AS to receive the
device list changes.
experimental_feature_enabled: Whether the "msc3202_transaction_extensions" experimental
feature is enabled. This feature must be enabled for device list updates to be sent to ASes.
as_supports_txn_extensions: Whether the application service has explicitly registered
to receive information defined by MSC3202 - which includes device list changes.
"""
# Change whether the experimental feature is enabled or disabled before making
# device list changes
self.as_handler._msc3202_transaction_extensions_enabled = (
experimental_feature_enabled
)
# Create an appservice that is interested in "local_user"
appservice = ApplicationService(
token=random_string(10),
hostname="example.com",
id=random_string(10),
sender="@as:example.com",
rate_limited=False,
namespaces={
ApplicationService.NS_USERS: [
{
"regex": "@local_user:.+",
"exclusive": False,
}
],
},
supports_ephemeral=True,
msc3202_transaction_extensions=as_supports_txn_extensions,
# Must be set for Synapse to try pushing data to the AS
hs_token="abcde",
url="some_url",
)
# Register the application service
self._services.append(appservice)
# Register a user on the homeserver
self.local_user = self.register_user("local_user", "password")
self.local_user_token = self.login("local_user", "password")
if as_should_receive_device_list_updates:
# Ensure that the resulting JSON uses the unstable prefix and contains the
# expected users
self.put_json.assert_called_once()
json_body = self.put_json.call_args.kwargs["json_body"]
# Our application service should have received a device list update with
# "local_user" in the "changed" list
device_list_dict = json_body.get("org.matrix.msc3202.device_lists", {})
self.assertEqual([], device_list_dict["left"])
self.assertEqual([self.local_user], device_list_dict["changed"])
else:
# No device list changes should have been sent out
self.put_json.assert_not_called()
class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
# Argument indices for pulling out arguments from a `send_mock`.
ARG_OTK_COUNTS = 4

View File

@@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import (
ApplicationServiceStore,
ApplicationServiceTransactionStore,
)
from synapse.types import DeviceLists
from synapse.util import Clock
from tests import unittest
@@ -267,7 +268,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
txn = self.get_success(
defer.ensureDeferred(
self.store.create_appservice_txn(service, events, [], [], {}, {})
self.store.create_appservice_txn(
service, events, [], [], {}, {}, DeviceLists()
)
)
)
self.assertEqual(txn.id, 1)
@@ -283,7 +286,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self._insert_txn(service.id, 9644, events))
self.get_success(self._insert_txn(service.id, 9645, events))
txn = self.get_success(
self.store.create_appservice_txn(service, events, [], [], {}, {})
self.store.create_appservice_txn(
service, events, [], [], {}, {}, DeviceLists()
)
)
self.assertEqual(txn.id, 9646)
self.assertEqual(txn.events, events)
@@ -296,7 +301,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
self.get_success(self._set_last_txn(service.id, 9643))
txn = self.get_success(
self.store.create_appservice_txn(service, events, [], [], {}, {})
self.store.create_appservice_txn(
service, events, [], [], {}, {}, DeviceLists()
)
)
self.assertEqual(txn.id, 9644)
self.assertEqual(txn.events, events)
@@ -320,7 +327,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events))
txn = self.get_success(
self.store.create_appservice_txn(service, events, [], [], {}, {})
self.store.create_appservice_txn(
service, events, [], [], {}, {}, DeviceLists()
)
)
self.assertEqual(txn.id, 9644)
self.assertEqual(txn.events, events)
@@ -476,12 +485,12 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
value = self.get_success(
self.store.get_type_stream_id_for_appservice(self.service, "read_receipt")
)
self.assertEqual(value, 0)
self.assertEqual(value, 1)
value = self.get_success(
self.store.get_type_stream_id_for_appservice(self.service, "presence")
)
self.assertEqual(value, 0)
self.assertEqual(value, 1)
def test_get_type_stream_id_for_appservice_invalid_type(self) -> None:
self.get_failure(