Compare commits

...

13 Commits

Author SHA1 Message Date
Olivier 'reivilibre
bc8bc5029a Use non-streaming cache invalidations in store methods 2025-07-18 11:55:48 +01:00
Olivier 'reivilibre
33229e88db add ON DELETE CASCADE to FK for rooms 2025-07-18 11:52:44 +01:00
Olivier 'reivilibre
f0e2b2956d fully descriptive names for _txn functions 2025-07-18 10:11:15 +01:00
Olivier 'reivilibre
115978889c Merge branch 'develop' into rei/threads2_sub 2025-07-18 10:06:44 +01:00
Olivier 'reivilibre
fdd1d63722 aside: add warnings about MultiWriterIdGenerator starting at 2 2025-07-15 11:33:58 +01:00
Olivier 'reivilibre
e334c9778e Newsfile
Signed-off-by: Olivier 'reivilibre <oliverw@matrix.org>
2025-07-15 11:33:58 +01:00
Olivier 'reivilibre
b6ca9f93a1 docker workers: register thread_subscriptions worker as worker type 2025-07-15 11:33:58 +01:00
Olivier 'reivilibre
398241a986 Add tests for thread subscription endpoints 2025-07-15 11:33:58 +01:00
Olivier 'reivilibre
a687cb275b aside: note that EventID only applies to room versions 1 and 2 2025-07-15 11:33:58 +01:00
Olivier 'reivilibre
5f5b5645e8 Add thread subscriptions endpoints 2025-07-15 11:33:58 +01:00
Olivier 'reivilibre
7d5597cb01 Add thread subscriptions handler 2025-07-15 11:33:58 +01:00
Olivier 'reivilibre
f291ce16c6 Add stream backed by thread_subscriptions table 2025-07-15 11:33:58 +01:00
Olivier 'reivilibre
d74d1dfa72 Add thread_subscriptions table 2025-07-11 11:23:37 +01:00
25 changed files with 1522 additions and 3 deletions

View File

@@ -0,0 +1 @@
Add experimental and incomplete support for [MSC4306: Thread Subscriptions](https://github.com/matrix-org/matrix-spec-proposals/blob/rei/msc_thread_subscriptions/proposals/4306-thread-subscriptions.md).

View File

@@ -327,6 +327,15 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
"shared_extra_conf": {},
"worker_extra_conf": "",
},
"thread_subscriptions": {
"app": "synapse.app.generic_worker",
"listener_resources": ["client", "replication"],
"endpoint_patterns": [
"^/_matrix/client/unstable/io.element.msc4306/.*",
],
"shared_extra_conf": {},
"worker_extra_conf": "",
},
}
# Templates for sections that may be inserted multiple times in config files
@@ -427,6 +436,7 @@ def add_worker_roles_to_shared_config(
"to_device",
"typing",
"push_rules",
"thread_subscriptions",
}
# Worker-type specific sharding config. Now a single worker can fulfill multiple

View File

@@ -136,6 +136,7 @@ BOOLEAN_COLUMNS = {
"has_known_state",
"is_encrypted",
],
"thread_subscriptions": ["subscribed", "automatic"],
"users": ["shadow_banned", "approved", "locked", "suspended"],
"un_partial_stated_event_stream": ["rejection_status_changed"],
"users_who_share_rooms": ["share_private"],

View File

@@ -104,6 +104,9 @@ from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.databases.main.stream import StreamWorkerStore
from synapse.storage.databases.main.tags import TagsWorkerStore
from synapse.storage.databases.main.task_scheduler import TaskSchedulerWorkerStore
from synapse.storage.databases.main.thread_subscriptions import (
ThreadSubscriptionsWorkerStore,
)
from synapse.storage.databases.main.transactions import TransactionWorkerStore
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
from synapse.storage.databases.main.user_directory import UserDirectoryStore
@@ -132,6 +135,7 @@ class GenericWorkerStore(
KeyStore,
RoomWorkerStore,
DirectoryWorkerStore,
ThreadSubscriptionsWorkerStore,
PushRulesWorkerStore,
ApplicationServiceTransactionWorkerStore,
ApplicationServiceWorkerStore,

View File

@@ -581,3 +581,7 @@ class ExperimentalConfig(Config):
# MSC4155: Invite filtering
self.msc4155_enabled: bool = experimental.get("msc4155_enabled", False)
# MSC4306: Thread Subscriptions
# (and MSC4308: sliding sync extension for thread subscriptions)
self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False)

View File

@@ -174,6 +174,10 @@ class WriterLocations:
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
thread_subscriptions: List[str] = attr.ib(
default=["master"],
converter=_instance_to_list_converter,
)
@attr.s(auto_attribs=True)

View File

@@ -187,6 +187,9 @@ class DeactivateAccountHandler:
# Remove account data (including ignored users and push rules).
await self.store.purge_account_data_for_user(user_id)
# Remove thread subscriptions for the user
await self.store.purge_thread_subscription_settings_for_user(user_id)
# Delete any server-side backup keys
await self.store.bulk_delete_backup_keys_and_versions_for_user(user_id)

View File

@@ -0,0 +1,126 @@
import logging
from typing import TYPE_CHECKING, Optional
from synapse.api.errors import AuthError, NotFoundError
from synapse.storage.databases.main.thread_subscriptions import ThreadSubscription
from synapse.types import UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class ThreadSubscriptionsHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.event_handler = hs.get_event_handler()
self.auth = hs.get_auth()
async def get_thread_subscription_settings(
self,
user_id: UserID,
room_id: str,
thread_root_event_id: str,
) -> Optional[ThreadSubscription]:
"""Get thread subscription settings for a specific thread and user.
Checks that the thread root is both a real event and also that it is visible
to the user.
Args:
user_id: The ID of the user
thread_root_event_id: The event ID of the thread root
Returns:
A `ThreadSubscription` containing the active subscription settings or None if not set
"""
# First check that the user can access the thread root event
# and that it exists
try:
event = await self.event_handler.get_event(
user_id, room_id, thread_root_event_id
)
if event is None:
raise NotFoundError("No such thread root")
except AuthError:
raise NotFoundError("No such thread root")
return await self.store.get_subscription_for_thread(
user_id.to_string(), event.room_id, thread_root_event_id
)
async def subscribe_user_to_thread(
self,
user_id: UserID,
room_id: str,
thread_root_event_id: str,
*,
automatic: bool,
) -> Optional[int]:
"""Sets or updates a user's subscription settings for a specific thread root.
Args:
requester_user_id: The ID of the user whose settings are being updated.
thread_root_event_id: The event ID of the thread root.
automatic: whether the user was subscribed by an automatic decision by
their client.
Returns:
The stream ID for this update, if the update isn't no-opped.
Raises:
NotFoundError if the user cannot access the thread root event, or it isn't
known to this homeserver.
"""
# First check that the user can access the thread root event
# and that it exists
try:
event = await self.event_handler.get_event(
user_id, room_id, thread_root_event_id
)
if event is None:
raise NotFoundError("No such thread root")
except AuthError:
logger.info("rejecting thread subscriptions change (thread not accessible)")
raise NotFoundError("No such thread root")
return await self.store.subscribe_user_to_thread(
user_id.to_string(),
event.room_id,
thread_root_event_id,
automatic=automatic,
)
async def unsubscribe_user_from_thread(
self, user_id: UserID, room_id: str, thread_root_event_id: str
) -> Optional[int]:
"""Clears a user's subscription settings for a specific thread root.
Args:
requester_user_id: The ID of the user whose settings are being updated.
thread_root_event_id: The event ID of the thread root.
Returns:
The stream ID for this update, if the update isn't no-opped.
Raises:
NotFoundError if the user cannot access the thread root event, or it isn't
known to this homeserver.
"""
# First check that the user can access the thread root event
# and that it exists
try:
event = await self.event_handler.get_event(
user_id, room_id, thread_root_event_id
)
if event is None:
raise NotFoundError("No such thread root")
except AuthError:
logger.info("rejecting thread subscriptions change (thread not accessible)")
raise NotFoundError("No such thread root")
return await self.store.unsubscribe_user_from_thread(
user_id.to_string(),
event.room_id,
thread_root_event_id,
)

View File

@@ -72,7 +72,10 @@ from synapse.replication.tcp.streams import (
ToDeviceStream,
TypingStream,
)
from synapse.replication.tcp.streams._base import DeviceListsStream
from synapse.replication.tcp.streams._base import (
DeviceListsStream,
ThreadSubscriptionsStream,
)
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -186,6 +189,15 @@ class ReplicationCommandHandler:
continue
if isinstance(stream, ThreadSubscriptionsStream):
if (
hs.get_instance_name()
in hs.config.worker.writers.thread_subscriptions
):
self._streams_to_replicate.append(stream)
continue
if isinstance(stream, DeviceListsStream):
if hs.get_instance_name() in hs.config.worker.writers.device_lists:
self._streams_to_replicate.append(stream)

View File

@@ -41,6 +41,7 @@ from synapse.replication.tcp.streams._base import (
PushRulesStream,
ReceiptsStream,
Stream,
ThreadSubscriptionsStream,
ToDeviceStream,
TypingStream,
)
@@ -67,6 +68,7 @@ STREAMS_MAP = {
ToDeviceStream,
FederationStream,
AccountDataStream,
ThreadSubscriptionsStream,
UnPartialStatedRoomStream,
UnPartialStatedEventStream,
)
@@ -86,6 +88,7 @@ __all__ = [
"DeviceListsStream",
"ToDeviceStream",
"AccountDataStream",
"ThreadSubscriptionsStream",
"UnPartialStatedRoomStream",
"UnPartialStatedEventStream",
]

View File

@@ -723,3 +723,47 @@ class AccountDataStream(_StreamFromIdGen):
heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0])
)
return updates, to_token, limited
class ThreadSubscriptionsStream(_StreamFromIdGen):
"""A thread subscription was changed."""
@attr.s(slots=True, auto_attribs=True)
class ThreadSubscriptionsStreamRow:
"""Stream to inform workers about changes to thread subscriptions."""
user_id: str
room_id: str
event_id: str # The event ID of the thread root
NAME = "thread_subscriptions"
ROW_TYPE = ThreadSubscriptionsStreamRow
def __init__(self, hs: Any):
self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
self._update_function,
self.store._thread_subscriptions_id_gen,
)
async def _update_function(
self, instance_name: str, from_token: int, to_token: int, limit: int
) -> StreamUpdateResult:
updates = await self.store.get_updated_thread_subscriptions(
from_token, to_token, limit
)
rows = [
(
stream_id,
# These are the args to `ThreadSubscriptionsStreamRow`
(user_id, room_id, event_id),
)
for stream_id, user_id, room_id, event_id in updates
]
logger.error("TS %d->%d %r", from_token, to_token, rows)
if not rows:
return [], to_token, False
return rows, rows[-1][0], len(updates) == limit

View File

@@ -0,0 +1,98 @@
from http import HTTPStatus
from typing import Tuple
from synapse._pydantic_compat import StrictBool
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
parse_and_validate_json_object_from_request,
)
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomID
from synapse.types.rest import RequestBodyModel
class ThreadSubscriptionsRestServlet(RestServlet):
PATTERNS = client_patterns(
"/io.element.msc4306/rooms/(?P<room_id>[^/]*)/thread/(?P<thread_root_id>[^/]*)/subscription$",
unstable=True,
releases=(),
)
CATEGORY = "Thread Subscriptions requests (unstable)"
def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
self.is_mine = hs.is_mine
self.store = hs.get_datastores().main
self.handler = hs.get_thread_subscriptions_handler()
class PutBody(RequestBodyModel):
automatic: StrictBool
async def on_GET(
self, request: SynapseRequest, room_id: str, thread_root_id: str
) -> Tuple[int, JsonDict]:
RoomID.from_string(room_id)
if not thread_root_id.startswith("$"):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Invalid event ID", errcode=Codes.INVALID_PARAM
)
requester = await self.auth.get_user_by_req(request)
subscription = await self.handler.get_thread_subscription_settings(
requester.user,
room_id,
thread_root_id,
)
if subscription is None:
raise NotFoundError("Not subscribed.")
return HTTPStatus.OK, {"automatic": subscription.automatic}
async def on_PUT(
self, request: SynapseRequest, room_id: str, thread_root_id: str
) -> Tuple[int, JsonDict]:
RoomID.from_string(room_id)
if not thread_root_id.startswith("$"):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Invalid event ID", errcode=Codes.INVALID_PARAM
)
requester = await self.auth.get_user_by_req(request)
body = parse_and_validate_json_object_from_request(request, self.PutBody)
await self.handler.subscribe_user_to_thread(
requester.user,
room_id,
thread_root_id,
automatic=body.automatic,
)
return HTTPStatus.OK, {}
async def on_DELETE(
self, request: SynapseRequest, room_id: str, thread_root_id: str
) -> Tuple[int, JsonDict]:
RoomID.from_string(room_id)
if not thread_root_id.startswith("$"):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Invalid event ID", errcode=Codes.INVALID_PARAM
)
requester = await self.auth.get_user_by_req(request)
await self.handler.unsubscribe_user_from_thread(
requester.user,
room_id,
thread_root_id,
)
return HTTPStatus.OK, {}
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
if hs.config.experimental.msc4306_enabled:
ThreadSubscriptionsRestServlet(hs).register(http_server)

View File

@@ -117,6 +117,7 @@ from synapse.handlers.sliding_sync import SlidingSyncHandler
from synapse.handlers.sso import SsoHandler
from synapse.handlers.stats import StatsHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.thread_subscriptions import ThreadSubscriptionsHandler
from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
from synapse.handlers.user_directory import UserDirectoryHandler
from synapse.handlers.worker_lock import WorkerLocksHandler
@@ -789,6 +790,10 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_timestamp_lookup_handler(self) -> TimestampLookupHandler:
return TimestampLookupHandler(self)
@cache_in_self
def get_thread_subscriptions_handler(self) -> ThreadSubscriptionsHandler:
return ThreadSubscriptionsHandler(self)
@cache_in_self
def get_registration_handler(self) -> RegistrationHandler:
return RegistrationHandler(self)

View File

@@ -19,7 +19,6 @@
# [This file includes modifications made by New Vector Limited]
#
#
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast
@@ -35,6 +34,9 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.sliding_sync import SlidingSyncStore
from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.databases.main.thread_subscriptions import (
ThreadSubscriptionsWorkerStore,
)
from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.types import Cursor
from synapse.types import get_domain_from_id
@@ -141,6 +143,7 @@ class DataStore(
SearchStore,
TagsStore,
AccountDataStore,
ThreadSubscriptionsWorkerStore,
PushRulesWorkerStore,
StreamWorkerStore,
OpenIdStore,

View File

@@ -2986,6 +2986,10 @@ class PersistEventsStore:
# Upsert into the threads table, but only overwrite the value if the
# new event is of a later topological order OR if the topological
# ordering is equal, but the stream ordering is later.
# (Note by definition that the stream ordering will always be later
# unless this is a backfilled event [= negative stream ordering]
# because we are only persisting this event now and stream_orderings
# are strictly monotonically increasing)
sql = """
INSERT INTO threads (room_id, thread_id, latest_event_id, topological_ordering, stream_ordering)
VALUES (?, ?, ?, ?, ?)

View File

@@ -0,0 +1,382 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
import logging
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
cast,
)
import attr
from synapse.replication.tcp.streams._base import ThreadSubscriptionsStream
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadSubscription:
automatic: bool
"""
whether the subscription was made automatically (as opposed to by manual
action from the user)
"""
class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self._can_write_to_thread_subscriptions = (
self._instance_name in hs.config.worker.writers.thread_subscriptions
)
self._thread_subscriptions_id_gen: MultiWriterIdGenerator = (
MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="thread_subscriptions",
instance_name=self._instance_name,
tables=[
("thread_subscriptions", "instance_name", "stream_id"),
],
sequence_name="thread_subscriptions_sequence",
writers=hs.config.worker.writers.thread_subscriptions,
)
)
def process_replication_rows(
self,
stream_name: str,
instance_name: str,
token: int,
rows: Iterable[Any],
) -> None:
if stream_name == ThreadSubscriptionsStream.NAME:
for row in rows:
self.get_subscription_for_thread.invalidate(
(row.user_id, row.room_id, row.event_id)
)
super().process_replication_rows(stream_name, instance_name, token, rows)
def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == ThreadSubscriptionsStream.NAME:
self._thread_subscriptions_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)
async def subscribe_user_to_thread(
self, user_id: str, room_id: str, thread_root_event_id: str, *, automatic: bool
) -> Optional[int]:
"""Updates a user's subscription settings for a specific thread root.
If no change would be made to the subscription, does not produce any database change.
Args:
user_id: The ID of the user whose settings are being updated.
room_id: The ID of the room the thread root belongs to.
thread_root_event_id: The event ID of the thread root.
automatic: Whether the subscription was performed automatically by the user's client.
Only `False` will overwrite an existing value of automatic for a subscription row.
Returns:
The stream ID for this update, if the update isn't no-opped.
"""
assert self._can_write_to_thread_subscriptions
def _subscribe_user_to_thread_txn(txn: LoggingTransaction) -> Optional[int]:
already_automatic = self.db_pool.simple_select_one_onecol_txn(
txn,
table="thread_subscriptions",
keyvalues={
"user_id": user_id,
"event_id": thread_root_event_id,
"room_id": room_id,
"subscribed": True,
},
retcol="automatic",
allow_none=True,
)
if already_automatic is None:
already_subscribed = False
already_automatic = True
else:
already_subscribed = True
# convert int (SQLite bool) to Python bool
already_automatic = bool(already_automatic)
if already_subscribed and already_automatic == automatic:
# there is nothing we need to do here
return None
stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn)
values: Dict[str, Optional[Union[bool, int, str]]] = {
"subscribed": True,
"stream_id": stream_id,
"instance_name": self._instance_name,
"automatic": already_automatic and automatic,
}
self.db_pool.simple_upsert_txn(
txn,
table="thread_subscriptions",
keyvalues={
"user_id": user_id,
"event_id": thread_root_event_id,
"room_id": room_id,
},
values=values,
)
txn.call_after(
self.get_subscription_for_thread.invalidate,
(user_id, room_id, thread_root_event_id),
)
return stream_id
return await self.db_pool.runInteraction(
"subscribe_user_to_thread", _subscribe_user_to_thread_txn
)
async def unsubscribe_user_from_thread(
self, user_id: str, room_id: str, thread_root_event_id: str
) -> Optional[int]:
"""Unsubscribes a user from a thread.
If no change would be made to the subscription, does not produce any database change.
Args:
user_id: The ID of the user whose settings are being updated.
room_id: The ID of the room the thread root belongs to.
thread_root_event_id: The event ID of the thread root.
Returns:
The stream ID for this update, if the update isn't no-opped.
"""
assert self._can_write_to_thread_subscriptions
def _unsubscribe_user_from_thread_txn(txn: LoggingTransaction) -> Optional[int]:
already_subscribed = self.db_pool.simple_select_one_onecol_txn(
txn,
table="thread_subscriptions",
keyvalues={
"user_id": user_id,
"event_id": thread_root_event_id,
"room_id": room_id,
},
retcol="subscribed",
allow_none=True,
)
if already_subscribed is None or already_subscribed is False:
# there is nothing we need to do here
return None
stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn)
self.db_pool.simple_update_txn(
txn,
table="thread_subscriptions",
keyvalues={
"user_id": user_id,
"event_id": thread_root_event_id,
"room_id": room_id,
"subscribed": True,
},
updatevalues={
"subscribed": False,
"stream_id": stream_id,
"instance_name": self._instance_name,
},
)
txn.call_after(
self.get_subscription_for_thread.invalidate,
(user_id, room_id, thread_root_event_id),
)
return stream_id
return await self.db_pool.runInteraction(
"unsubscribe_user_from_thread", _unsubscribe_user_from_thread_txn
)
async def purge_thread_subscription_settings_for_user(self, user_id: str) -> None:
"""
Purge all subscriptions for the user.
The fact that subscriptions have been purged will not be streamed;
all stream rows for the user will in fact be removed.
This is intended only for dealing with user deactivation.
"""
def _purge_thread_subscription_settings_for_user_txn(
txn: LoggingTransaction,
) -> None:
self.db_pool.simple_delete_txn(
txn,
table="thread_subscriptions",
keyvalues={"user_id": user_id},
)
self._invalidate_cache_and_stream(
txn, self.get_subscription_for_thread, (user_id,)
)
await self.db_pool.runInteraction(
desc="purge_thread_subscription_settings_for_user",
func=_purge_thread_subscription_settings_for_user_txn,
)
@cached(tree=True)
async def get_subscription_for_thread(
self, user_id: str, room_id: str, thread_root_event_id: str
) -> Optional[ThreadSubscription]:
"""Get the thread subscription for a specific thread and user.
Args:
user_id: The ID of the user
room_id: The ID of the room
thread_root_event_id: The event ID of the thread root
Returns:
A `ThreadSubscription` dataclass if there is a subscription,
or `None` if there is no subscription.
If there is a row in the table but `subscribed` is `False`,
behaves the same as if there was no row at all and returns `None`.
"""
row = await self.db_pool.simple_select_one(
table="thread_subscriptions",
keyvalues={
"user_id": user_id,
"room_id": room_id,
"event_id": thread_root_event_id,
"subscribed": True,
},
retcols=("automatic",),
allow_none=True,
desc="get_subscription_for_thread",
)
if row is None:
return None
(automatic_rawbool,) = row
# convert SQLite integer booleans into real booleans
automatic = bool(automatic_rawbool)
return ThreadSubscription(automatic=automatic)
def get_max_thread_subscriptions_stream_id(self) -> int:
"""Get the current maximum stream_id for thread subscriptions.
Returns:
The maximum stream_id
"""
return self._thread_subscriptions_id_gen.get_current_token()
async def get_updated_thread_subscriptions(
self, from_id: int, to_id: int, limit: int
) -> List[Tuple[int, str, str, str]]:
"""Get updates to thread subscriptions between two stream IDs.
Args:
from_id: The starting stream ID (exclusive)
to_id: The ending stream ID (inclusive)
limit: The maximum number of rows to return
Returns:
list of (stream_id, user_id, room_id, thread_root_id) tuples
"""
def get_updated_thread_subscriptions_txn(
txn: LoggingTransaction,
) -> List[Tuple[int, str, str, str]]:
sql = """
SELECT stream_id, user_id, room_id, event_id
FROM thread_subscriptions
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (from_id, to_id, limit))
return cast(List[Tuple[int, str, str, str]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_updated_thread_subscriptions",
get_updated_thread_subscriptions_txn,
)
async def get_updated_thread_subscriptions_for_user(
self, user_id: str, from_id: int, to_id: int, limit: int
) -> List[Tuple[int, str, str]]:
"""Get updates to thread subscriptions for a specific user.
Args:
user_id: The ID of the user
from_id: The starting stream ID (exclusive)
to_id: The ending stream ID (inclusive)
limit: The maximum number of rows to return
Returns:
A list of (stream_id, room_id, thread_root_event_id) tuples.
"""
def get_updated_thread_subscriptions_for_user_txn(
txn: LoggingTransaction,
) -> List[Tuple[int, str, str]]:
sql = """
SELECT stream_id, room_id, event_id
FROM thread_subscriptions
WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (user_id, from_id, to_id, limit))
return [(row[0], row[1], row[2]) for row in txn]
return await self.db_pool.runInteraction(
"get_updated_thread_subscriptions_for_user",
get_updated_thread_subscriptions_for_user_txn,
)

View File

@@ -0,0 +1,59 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2025 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- Introduce a table for tracking users' subscriptions to threads.
CREATE TABLE thread_subscriptions (
stream_id INTEGER NOT NULL PRIMARY KEY,
instance_name TEXT NOT NULL,
room_id TEXT NOT NULL,
event_id TEXT NOT NULL,
user_id TEXT NOT NULL,
subscribed BOOLEAN NOT NULL,
automatic BOOLEAN NOT NULL,
CONSTRAINT thread_subscriptions_fk_users
FOREIGN KEY (user_id)
REFERENCES users(name),
CONSTRAINT thread_subscriptions_fk_rooms
FOREIGN KEY (room_id)
-- When we delete a room, we should already have deleted all the events in that room
-- and so there shouldn't be any subscriptions left in that room.
-- So the `ON DELETE CASCADE` should be optional, but included anyway for good measure.
REFERENCES rooms(room_id) ON DELETE CASCADE,
CONSTRAINT thread_subscriptions_fk_events
FOREIGN KEY (event_id)
REFERENCES events(event_id) ON DELETE CASCADE,
-- This order provides a useful index for:
-- 1. foreign key constraint on (room_id)
-- 2. foreign key constraint on (room_id, event_id)
-- 3. finding the user's settings for a specific thread (as well as enforcing uniqueness)
UNIQUE (room_id, event_id, user_id)
);
-- this provides a useful index for finding a user's own rules,
-- potentially scoped to a single room
CREATE INDEX thread_subscriptions_user_room ON thread_subscriptions (user_id, room_id);
-- this provides a useful way for clients to efficiently find new changes to
-- their subscriptions.
-- (This is necessary to sync subscriptions between multiple devices.)
CREATE INDEX thread_subscriptions_by_user ON thread_subscriptions (user_id, stream_id);
-- this provides a useful index for deleting the subscriptions when the underlying
-- events are removed. This also covers the foreign key constraint on `events`.
CREATE INDEX thread_subscriptions_by_event ON thread_subscriptions (event_id);

View File

@@ -0,0 +1,19 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2025 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
CREATE SEQUENCE thread_subscriptions_sequence
-- Synapse streams start at 2, because the default position is 1
-- so any item inserted at position 1 is ignored.
-- This is also what existing streams do, except they use `setval(..., 1)`
-- which is semantically the same except less obvious.
START WITH 2;

View File

@@ -0,0 +1,18 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2025 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
COMMENT ON TABLE thread_subscriptions IS 'Tracks local users that subscribe to threads';
COMMENT ON COLUMN thread_subscriptions.subscribed IS 'Whether the user is subscribed to the thread or not. We track unsubscribed threads because we need to stream the subscription change to the client.';
COMMENT ON COLUMN thread_subscriptions.automatic IS 'True if the user was subscribed to the thread automatically by their client, or false if the client manually requested the subscription.';

View File

@@ -0,0 +1,24 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2025 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
COMMENT ON COLUMN threads.latest_event_id IS
'the ID of the event that is latest, ordered by (topological_ordering, stream_ordering)';
COMMENT ON COLUMN threads.topological_ordering IS
$$the topological ordering of the thread''s LATEST event.
Used as the primary way of ordering threads by recency in a room.$$;
COMMENT ON COLUMN threads.stream_ordering IS
$$the stream ordering of the thread's LATEST event.
Used as a tie-breaker for ordering threads by recency in a room, when the topological order is a tie.
Also used for recency ordering in sliding sync.$$;

View File

@@ -184,6 +184,12 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
Note: Only works with Postgres.
Warning: Streams using this generator start at ID 2, because ID 1 is always assumed
to have been 'seen as persisted'.
Unclear if this extant behaviour is desirable for some reason.
When creating a new sequence for a new stream,
it will be necessary to use `START WITH 2`.
Args:
db_conn
db
@@ -269,6 +275,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self._known_persisted_positions: List[int] = []
# The maximum stream ID that we have seen been allocated across any writer.
# Since this defaults to 1, this means that ID 1 is assumed to have already
# been 'seen'. In other words, multi-writer streams start at 2.
# Unclear if this is desirable behaviour.
self._max_seen_allocated_stream_id = 1
# The maximum position of the local instance. This can be higher than

View File

@@ -362,7 +362,8 @@ class RoomID(DomainSpecificString):
@attr.s(slots=True, frozen=True, repr=False)
class EventID(DomainSpecificString):
"""Structure representing an event id."""
"""Structure representing an event ID which is namespaced to a homeserver.
Room versions 3 and above are not supported by this grammar."""
SIGIL = "$"

View File

@@ -0,0 +1,157 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
from twisted.test.proto_helpers import MemoryReactor
from synapse.replication.tcp.streams._base import (
_STREAM_UPDATE_TARGET_ROW_COUNT,
ThreadSubscriptionsStream,
)
from synapse.server import HomeServer
from synapse.storage.database import LoggingTransaction
from synapse.util import Clock
from tests.replication._base import BaseStreamTestCase
class ThreadSubscriptionsStreamTestCase(BaseStreamTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
# Postgres
def f(txn: LoggingTransaction) -> None:
txn.execute(
"""
ALTER TABLE thread_subscriptions
DROP CONSTRAINT thread_subscriptions_fk_users,
DROP CONSTRAINT thread_subscriptions_fk_rooms,
DROP CONSTRAINT thread_subscriptions_fk_events;
""",
)
self.get_success(
self.hs.get_datastores().main.db_pool.runInteraction(
"disable_foreign_keys", f
)
)
def test_thread_subscription_updates(self) -> None:
"""Test replication with thread subscription updates"""
store = self.hs.get_datastores().main
# Create thread subscription updates
updates = []
room_id = "!test_room:example.com"
# Generate several thread subscription updates
for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
thread_root_id = f"$thread_{i}:example.com"
self.get_success(
store.subscribe_user_to_thread(
"@test_user:example.org",
room_id,
thread_root_id,
automatic=True,
)
)
updates.append(thread_root_id)
# Also add one in a different room
other_room_id = "!other_room:example.com"
other_thread_root_id = "$other_thread:example.com"
self.get_success(
store.subscribe_user_to_thread(
"@test_user:example.org",
other_room_id,
other_thread_root_id,
automatic=False,
)
)
# Not yet connected: no rows should yet have been received
self.assertEqual([], self.test_handler.received_rdata_rows)
# Now reconnect to pull the updates
self.reconnect()
self.replicate()
# We should have received all the expected rows in the right order
# Filter the updates to only include thread subscription changes
received_rows = [
upd
for upd in self.test_handler.received_rdata_rows
if upd[0] == ThreadSubscriptionsStream.NAME
]
# Verify all the thread subscription updates
for thread_id in updates:
(stream_name, token, row) = received_rows.pop(0)
self.assertEqual(stream_name, ThreadSubscriptionsStream.NAME)
self.assertIsInstance(row, ThreadSubscriptionsStream.ROW_TYPE)
self.assertEqual(row.user_id, "@test_user:example.org")
self.assertEqual(row.room_id, room_id)
self.assertEqual(row.event_id, thread_id)
# Verify the last update in the different room
(stream_name, token, row) = received_rows.pop(0)
self.assertEqual(stream_name, ThreadSubscriptionsStream.NAME)
self.assertIsInstance(row, ThreadSubscriptionsStream.ROW_TYPE)
self.assertEqual(row.user_id, "@test_user:example.org")
self.assertEqual(row.room_id, other_room_id)
self.assertEqual(row.event_id, other_thread_root_id)
self.assertEqual([], received_rows)
def test_multiple_users_thread_subscription_updates(self) -> None:
"""Test replication with thread subscription updates for multiple users"""
store = self.hs.get_datastores().main
room_id = "!test_room:example.com"
thread_root_id = "$thread_root:example.com"
# Create updates for multiple users
users = ["@user1:example.com", "@user2:example.com", "@user3:example.com"]
for user_id in users:
self.get_success(
store.subscribe_user_to_thread(
user_id, room_id, thread_root_id, automatic=True
)
)
# Check no rows have been received yet
self.replicate()
self.assertEqual([], self.test_handler.received_rdata_rows)
# Not yet connected: no rows should yet have been received
self.reconnect()
self.replicate()
# We should have received all the expected rows
# Filter the updates to only include thread subscription changes
received_rows = [
upd
for upd in self.test_handler.received_rdata_rows
if upd[0] == ThreadSubscriptionsStream.NAME
]
# Should have one update per user
self.assertEqual(len(received_rows), len(users))
# Verify all updates
for i, user_id in enumerate(users):
(stream_name, token, row) = received_rows[i]
self.assertEqual(stream_name, ThreadSubscriptionsStream.NAME)
self.assertIsInstance(row, ThreadSubscriptionsStream.ROW_TYPE)
self.assertEqual(row.user_id, user_id)
self.assertEqual(row.room_id, room_id)
self.assertEqual(row.event_id, thread_root_id)

View File

@@ -0,0 +1,256 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
from http import HTTPStatus
from twisted.test.proto_helpers import MemoryReactor
from synapse.rest import admin
from synapse.rest.client import login, profile, room, thread_subscriptions
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
PREFIX = "/_matrix/client/unstable/io.element.msc4306/rooms"
class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets_for_client_rest_resource,
login.register_servlets,
profile.register_servlets,
room.register_servlets,
thread_subscriptions.register_servlets,
]
def default_config(self) -> JsonDict:
config = super().default_config()
config["experimental_features"] = {"msc4306_enabled": True}
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("user", "password")
self.token = self.login("user", "password")
self.other_user_id = self.register_user("other_user", "password")
self.other_token = self.login("other_user", "password")
# Create a room and send a message to use as a thread root
self.room_id = self.helper.create_room_as(self.user_id, tok=self.token)
self.helper.join(self.room_id, self.other_user_id, tok=self.other_token)
response = self.helper.send(self.room_id, body="Root message", tok=self.token)
self.root_event_id = response["event_id"]
# Send a message in the thread
self.helper.send_event(
room_id=self.room_id,
type="m.room.message",
content={
"body": "Thread message",
"msgtype": "m.text",
"m.relates_to": {
"rel_type": "m.thread",
"event_id": self.root_event_id,
},
},
tok=self.token,
)
def test_get_thread_subscription_unsubscribed(self) -> None:
"""Test retrieving thread subscription when not subscribed."""
channel = self.make_request(
"GET",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
def test_get_thread_subscription_nonexistent_thread(self) -> None:
"""Test retrieving subscription settings for a nonexistent thread."""
channel = self.make_request(
"GET",
f"{PREFIX}/{self.room_id}/thread/$nonexistent:example.org/subscription",
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
def test_get_thread_subscription_no_access(self) -> None:
"""Test that a user can't get thread subscription for a thread they can't access."""
self.register_user("no_access", "password")
no_access_token = self.login("no_access", "password")
channel = self.make_request(
"GET",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
access_token=no_access_token,
)
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
def test_subscribe_manual_then_automatic(self) -> None:
"""Test subscribing to a thread, first a manual subscription then an automatic subscription.
The manual subscription wins over the automatic one."""
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{
"automatic": False,
},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK)
# Assert the subscription was saved
channel = self.make_request(
"GET",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(channel.json_body, {"automatic": False})
# Now also register an automatic subscription; it should not
# override the manual subscription
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{"automatic": True},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK)
# Assert the manual subscription was not overridden
channel = self.make_request(
"GET",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(channel.json_body, {"automatic": False})
def test_subscribe_automatic_then_manual(self) -> None:
"""Test subscribing to a thread, first an automatic subscription then a manual subscription.
The manual subscription wins over the automatic one."""
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{
"automatic": True,
},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK)
# Assert the subscription was saved
channel = self.make_request(
"GET",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(channel.json_body, {"automatic": True})
# Now also register a manual subscription
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{"automatic": False},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK)
# Assert the manual subscription was not overridden
channel = self.make_request(
"GET",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(channel.json_body, {"automatic": False})
def test_unsubscribe(self) -> None:
"""Test subscribing to a thread, then unsubscribing."""
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{
"automatic": True,
},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK)
# Assert the subscription was saved
channel = self.make_request(
"GET",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(channel.json_body, {"automatic": True})
# Now also register a manual subscription
channel = self.make_request(
"DELETE",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK)
# Assert the manual subscription was not overridden
channel = self.make_request(
"GET",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
def test_set_thread_subscription_nonexistent_thread(self) -> None:
"""Test setting subscription settings for a nonexistent thread."""
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/$nonexistent:example.org/subscription",
{"automatic": True},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
def test_set_thread_subscription_no_access(self) -> None:
"""Test that a user can't set thread subscription for a thread they can't access."""
self.register_user("no_access2", "password")
no_access_token = self.login("no_access2", "password")
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{"automatic": True},
access_token=no_access_token,
)
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
def test_invalid_body(self) -> None:
"""Test that sending invalid subscription settings is rejected."""
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
# non-boolean `automatic`
{"automatic": "true"},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST)

View File

@@ -0,0 +1,272 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
from typing import Optional
from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
from synapse.storage.database import LoggingTransaction
from synapse.storage.engines.sqlite import Sqlite3Engine
from synapse.util import Clock
from tests import unittest
class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = self.hs.get_datastores().main
self.user_id = "@user:test"
self.room_id = "!room:test"
self.thread_root_id = "$thread_root:test"
self.other_thread_root_id = "$other_thread_root:test"
# Disable foreign key checks for testing
# This allows us to insert test data without having to create actual events
db_pool = self.store.db_pool
if isinstance(db_pool.engine, Sqlite3Engine):
self.get_success(
db_pool.execute("disable_foreign_keys", "PRAGMA foreign_keys = OFF;")
)
else:
# Postgres
def f(txn: LoggingTransaction) -> None:
txn.execute(
"""
ALTER TABLE thread_subscriptions
DROP CONSTRAINT thread_subscriptions_fk_users,
DROP CONSTRAINT thread_subscriptions_fk_rooms,
DROP CONSTRAINT thread_subscriptions_fk_events;
""",
)
self.get_success(db_pool.runInteraction("disable_foreign_keys", f))
# Create rooms and events in the db to satisfy foreign key constraints
self.get_success(db_pool.simple_insert("rooms", {"room_id": self.room_id}))
self.get_success(
db_pool.simple_insert(
"events",
{
"event_id": self.thread_root_id,
"room_id": self.room_id,
"topological_ordering": 1,
"stream_ordering": 1,
"type": "m.room.message",
"depth": 1,
"processed": True,
"outlier": False,
},
)
)
self.get_success(
db_pool.simple_insert(
"events",
{
"event_id": self.other_thread_root_id,
"room_id": self.room_id,
"topological_ordering": 2,
"stream_ordering": 2,
"type": "m.room.message",
"depth": 2,
"processed": True,
"outlier": False,
},
)
)
# Create the user
self.get_success(
db_pool.simple_insert("users", {"name": self.user_id, "is_guest": 0})
)
def _subscribe(
self,
thread_root_id: str,
*,
automatic: bool,
room_id: Optional[str] = None,
user_id: Optional[str] = None,
) -> Optional[int]:
if user_id is None:
user_id = self.user_id
if room_id is None:
room_id = self.room_id
return self.get_success(
self.store.subscribe_user_to_thread(
user_id,
room_id,
thread_root_id,
automatic=automatic,
)
)
def _unsubscribe(
self,
thread_root_id: str,
room_id: Optional[str] = None,
user_id: Optional[str] = None,
) -> Optional[int]:
if user_id is None:
user_id = self.user_id
if room_id is None:
room_id = self.room_id
return self.get_success(
self.store.unsubscribe_user_from_thread(
user_id,
room_id,
thread_root_id,
)
)
def test_set_and_get_thread_subscription(self) -> None:
"""Test basic setting and getting of thread subscriptions."""
# Initial state: no subscription
subscription = self.get_success(
self.store.get_subscription_for_thread(
self.user_id, self.room_id, self.thread_root_id
)
)
self.assertIsNone(subscription)
# Subscribe
self._subscribe(
self.thread_root_id,
automatic=True,
)
# Assert subscription went through
subscription = self.get_success(
self.store.get_subscription_for_thread(
self.user_id, self.room_id, self.thread_root_id
)
)
self.assertIsNotNone(subscription)
self.assertTrue(subscription.automatic) # type: ignore
# Now make it a manual subscription
self._subscribe(
self.thread_root_id,
automatic=False,
)
# Assert the manual subscription overrode the automatic one
subscription = self.get_success(
self.store.get_subscription_for_thread(
self.user_id, self.room_id, self.thread_root_id
)
)
self.assertFalse(subscription.automatic) # type: ignore
def test_purge_thread_subscriptions_for_user(self) -> None:
"""Test purging all thread subscription settings for a user."""
# Set subscription settings for multiple threads
self._subscribe(self.thread_root_id, automatic=True)
self._subscribe(self.other_thread_root_id, automatic=False)
subscriptions = self.get_success(
self.store.get_updated_thread_subscriptions_for_user(
self.user_id,
from_id=0,
to_id=50,
limit=50,
)
)
min_id = min(id for (id, _, _) in subscriptions)
self.assertEqual(
subscriptions,
[
(min_id, self.room_id, self.thread_root_id),
(min_id + 1, self.room_id, self.other_thread_root_id),
],
)
# Purge all settings for the user
self.get_success(
self.store.purge_thread_subscription_settings_for_user(self.user_id)
)
# Check user has no subscriptions
subscriptions = self.get_success(
self.store.get_updated_thread_subscriptions_for_user(
self.user_id,
from_id=0,
to_id=50,
limit=50,
)
)
self.assertEqual(subscriptions, [])
def test_get_updated_thread_subscriptions(self) -> None:
"""Test getting updated thread subscriptions since a stream ID."""
stream_id1 = self._subscribe(self.thread_root_id, automatic=False)
stream_id2 = self._subscribe(self.other_thread_root_id, automatic=True)
assert stream_id1 is not None
assert stream_id2 is not None
# Get updates since initial ID (should include both changes)
updates = self.get_success(
self.store.get_updated_thread_subscriptions(0, stream_id2, 10)
)
self.assertEqual(len(updates), 2)
# Get updates since first change (should include only the second change)
updates = self.get_success(
self.store.get_updated_thread_subscriptions(stream_id1, stream_id2, 10)
)
self.assertEqual(
updates,
[(stream_id2, self.user_id, self.room_id, self.other_thread_root_id)],
)
def test_get_updated_thread_subscriptions_for_user(self) -> None:
"""Test getting updated thread subscriptions for a specific user."""
other_user_id = "@other_user:test"
# Set thread subscription for main user
stream_id1 = self._subscribe(self.thread_root_id, automatic=True)
assert stream_id1 is not None
# Set thread subscription for other user
stream_id2 = self._subscribe(
self.other_thread_root_id,
automatic=True,
user_id=other_user_id,
)
assert stream_id2 is not None
# Get updates for main user
updates = self.get_success(
self.store.get_updated_thread_subscriptions_for_user(
self.user_id, 0, stream_id2, 10
)
)
self.assertEqual(updates, [(stream_id1, self.room_id, self.thread_root_id)])
# Get updates for other user
updates = self.get_success(
self.store.get_updated_thread_subscriptions_for_user(
other_user_id, 0, max(stream_id1, stream_id2), 10
)
)
self.assertEqual(
updates, [(stream_id2, self.room_id, self.other_thread_root_id)]
)