Compare commits

...

23 Commits

Author SHA1 Message Date
Mathieu Velten
86641cb3a8 Merge remote-tracking branch 'origin/develop' into anoa/public_rooms_module_api 2023-10-10 16:55:30 +02:00
Mathieu Velten
99fefd5501 Merge remote-tracking branch 'origin/develop' into anoa/public_rooms_module_api 2023-10-09 15:19:19 +02:00
Mathieu Velten
1f4b960e62 argggggghhhh 2023-06-01 15:59:24 +02:00
Mathieu Velten
9c73f259eb fix postgres 2023-05-31 16:59:28 +02:00
Mathieu Velten
86ca31e705 Improvments 2023-05-31 16:03:57 +02:00
Mathieu Velten
a5c50548b0 Simplify code 2023-05-31 12:07:04 +02:00
Mathieu Velten
dcc49cd1ae More tests, less bugs 2023-05-30 17:41:43 +02:00
Mathieu Velten
3194933c1b Fix forwards pagination 2023-05-26 18:03:04 +02:00
Mathieu Velten
50d75a311b Fix test 2023-05-23 11:59:31 +02:00
Mathieu Velten
7709a99e6f Fix order 2023-05-22 18:32:32 +02:00
Mathieu Velten
b64aa1a3bb use attrs class 2023-05-22 17:26:50 +02:00
Mathieu Velten
3347725cc1 isort 2023-05-19 17:52:38 +02:00
Mathieu Velten
bd66f4384d types 2023-05-19 17:44:42 +02:00
Mathieu Velten
75f9e56c77 Merge remote-tracking branch 'origin/develop' into anoa/public_rooms_module_api 2023-05-19 17:35:54 +02:00
Mathieu Velten
a0ea6c1cba lint 2023-05-19 17:35:29 +02:00
Mathieu Velten
71885068e5 Add changelog + stuffs 2023-05-19 17:33:34 +02:00
Mathieu Velten
2362ef10a3 Fix stuffs 2023-05-19 16:56:26 +02:00
Mathieu Velten
74dbcaaab2 Fix stuffs 2023-05-19 16:46:02 +02:00
Mathieu Velten
e01ea0edc0 Add test 2023-05-19 12:29:05 +02:00
Mathieu Velten
79923666c5 Various fixups 2023-05-19 11:23:55 +02:00
Mathieu Velten
738b372379 Merge remote-tracking branch 'origin/develop' into anoa/public_rooms_module_api 2023-05-16 17:41:39 +02:00
Andrew Morgan
5c1e9f24da wip: call the public room callback 2023-05-02 15:23:32 +01:00
Andrew Morgan
2436153e8f Add a new public rooms callback class, a new fetch_public_rooms callback
fetch_public_rooms is a module API callback intended to be used when a
request for the homeserver's public rooms list comes in via either the
CS or SS API. Modules can return an ordered array of public rooms that
they would like to inject into the list supplied by the homeserver.

This can be useful for exposing known rooms that users on the
homeserver have not joined yet, and the property of mixing with the
normal public rooms list is desirable versus the solution of creating
a new third-party network type to load your rooms under.
2023-05-02 15:23:32 +01:00
12 changed files with 679 additions and 88 deletions

View File

@@ -0,0 +1 @@
Allow modules to provide local /publicRooms results.

View File

@@ -149,7 +149,10 @@ class PublicRoomList(BaseFederationServlet):
limit = None
data = await self.handler.get_local_public_room_list(
limit, since_token, network_tuple=network_tuple, from_federation=True
limit,
since_token,
network_tuple=network_tuple,
from_remote_server_name=origin,
)
return 200, data
@@ -190,7 +193,7 @@ class PublicRoomList(BaseFederationServlet):
since_token=since_token,
search_filter=search_filter,
network_tuple=network_tuple,
from_federation=True,
from_remote_server_name=origin,
)
return 200, data

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import attr
import msgpack
@@ -33,7 +33,8 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
from synapse.types import JsonDict, JsonMapping, ThirdPartyInstanceID
from synapse.types import JsonDict, JsonMapping, PublicRoom, ThirdPartyInstanceID
from synapse.util import filter_none
from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.caches.response_cache import ResponseCache
@@ -60,6 +61,7 @@ class RoomListHandler:
self.remote_response_cache: ResponseCache[
Tuple[str, Optional[int], Optional[str], bool, Optional[str]]
] = ResponseCache(hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000)
self._module_api_callbacks = hs.get_module_api_callbacks().public_rooms
async def get_local_public_room_list(
self,
@@ -67,7 +69,8 @@ class RoomListHandler:
since_token: Optional[str] = None,
search_filter: Optional[dict] = None,
network_tuple: Optional[ThirdPartyInstanceID] = EMPTY_THIRD_PARTY_ID,
from_federation: bool = False,
from_client_mxid: Optional[str] = None,
from_remote_server_name: Optional[str] = None,
) -> JsonDict:
"""Generate a local public room list.
@@ -75,14 +78,20 @@ class RoomListHandler:
party network. A client can ask for a specific list or to return all.
Args:
limit
since_token
search_filter
limit: The maximum number of rooms to return, or None to return all rooms.
since_token: A pagination token, or None to return the head of the public
rooms list.
search_filter: An optional dictionary with the following keys:
* generic_search_term: A string to search for in room ...
* room_types: A list to filter returned rooms by their type. If None or
an empty list is passed, rooms will not be filtered by type.
network_tuple: Which public list to use.
This can be (None, None) to indicate the main list, or a particular
appservice and network id to use an appservice specific one.
Setting to None returns all public rooms across all lists.
from_federation: true iff the request comes from the federation API
from_client_mxid: A user's MXID if this request came from a registered user.
from_remote_server_name: A remote homeserver's server name, if this
request came from the federation API.
"""
if not self.enable_room_list_search:
return {"chunk": [], "total_room_count_estimate": 0}
@@ -105,7 +114,8 @@ class RoomListHandler:
since_token,
search_filter,
network_tuple=network_tuple,
from_federation=from_federation,
from_client_mxid=from_client_mxid,
from_remote_server_name=from_remote_server_name,
)
key = (limit, since_token, network_tuple)
@@ -115,7 +125,8 @@ class RoomListHandler:
limit,
since_token,
network_tuple=network_tuple,
from_federation=from_federation,
from_client_mxid=from_client_mxid,
from_remote_server_name=from_remote_server_name,
)
async def _get_public_room_list(
@@ -124,7 +135,8 @@ class RoomListHandler:
since_token: Optional[str] = None,
search_filter: Optional[dict] = None,
network_tuple: Optional[ThirdPartyInstanceID] = EMPTY_THIRD_PARTY_ID,
from_federation: bool = False,
from_client_mxid: Optional[str] = None,
from_remote_server_name: Optional[str] = None,
) -> JsonDict:
"""Generate a public room list.
Args:
@@ -135,65 +147,106 @@ class RoomListHandler:
This can be (None, None) to indicate the main list, or a particular
appservice and network id to use an appservice specific one.
Setting to None returns all public rooms across all lists.
from_federation: Whether this request originated from a
federating server or a client. Used for room filtering.
from_client_mxid: A user's MXID if this request came from a registered user.
from_remote_server_name: A remote homeserver's server name, if this
request came from the federation API.
"""
# Pagination tokens work by storing the room ID sent in the last batch,
# plus the direction (forwards or backwards). Next batch tokens always
# go forwards, prev batch tokens always go backwards.
forwards = True
last_joined_members = None
last_room_id = None
last_module_index = None
if since_token:
batch_token = RoomListNextBatch.from_token(since_token)
bounds: Optional[Tuple[int, str]] = (
batch_token.last_joined_members,
batch_token.last_room_id,
)
print(batch_token)
forwards = batch_token.direction_is_forward
has_batch_token = True
else:
bounds = None
last_joined_members = batch_token.last_joined_members
last_room_id = batch_token.last_room_id
last_module_index = batch_token.last_module_index
forwards = True
has_batch_token = False
# we request one more than wanted to see if there are more pages to come
# We request one more than wanted to see if there are more pages to come
probing_limit = limit + 1 if limit is not None else None
results = await self.store.get_largest_public_rooms(
# We bucket results per joined members number since we want to keep order
# per joined members number
num_joined_members_buckets: Dict[int, List[PublicRoom]] = {}
room_ids_to_module_index: Dict[str, int] = {}
local_public_rooms = await self.store.get_largest_public_rooms(
network_tuple,
search_filter,
probing_limit,
bounds=bounds,
bounds=(
last_joined_members,
last_room_id if last_module_index is None else None,
),
forwards=forwards,
ignore_non_federatable=from_federation,
ignore_non_federatable=bool(from_remote_server_name),
)
def build_room_entry(room: JsonDict) -> JsonDict:
entry = {
"room_id": room["room_id"],
"name": room["name"],
"topic": room["topic"],
"canonical_alias": room["canonical_alias"],
"num_joined_members": room["joined_members"],
"avatar_url": room["avatar"],
"world_readable": room["history_visibility"]
== HistoryVisibility.WORLD_READABLE,
"guest_can_join": room["guest_access"] == "can_join",
"join_rule": room["join_rules"],
"room_type": room["room_type"],
}
for room in local_public_rooms:
num_joined_members_buckets.setdefault(room.num_joined_members, []).append(
room
)
# Filter out Nones rather omit the field altogether
return {k: v for k, v in entry.items() if v is not None}
nb_modules = len(self._module_api_callbacks.fetch_public_rooms_callbacks)
results = [build_room_entry(r) for r in results]
module_range = range(nb_modules)
# if not forwards:
# module_range = reversed(module_range)
for module_index in module_range:
fetch_public_rooms = (
self._module_api_callbacks.fetch_public_rooms_callbacks[module_index]
)
# Ask each module for a list of public rooms given the last_joined_members
# value from the since token and the probing limit
# last_joined_members needs to be reduce by one if this module has already
# given its result for last_joined_members
module_last_joined_members = last_joined_members
if module_last_joined_members is not None and last_module_index is not None:
if forwards and module_index < last_module_index:
module_last_joined_members = module_last_joined_members - 1
# if not forwards and module_index > last_module_index:
# module_last_joined_members = module_last_joined_members - 1
module_public_rooms = await fetch_public_rooms(
network_tuple,
search_filter,
probing_limit,
(
module_last_joined_members,
last_room_id if last_module_index == module_index else None,
),
forwards,
)
for room in module_public_rooms:
num_joined_members_buckets.setdefault(
room.num_joined_members, []
).append(room)
room_ids_to_module_index[room.room_id] = module_index
nums_joined_members = list(num_joined_members_buckets.keys())
nums_joined_members.sort(reverse=forwards)
results = []
for num_joined_members in nums_joined_members:
rooms = num_joined_members_buckets[num_joined_members]
# if not forwards:
# rooms.reverse()
results += rooms
print([(r.room_id, r.num_joined_members) for r in results])
response: JsonDict = {}
num_results = len(results)
if limit is not None:
more_to_come = num_results == probing_limit
if limit is not None and probing_limit is not None:
more_to_come = num_results >= probing_limit
# Depending on direction we trim either the front or back.
if forwards:
@@ -203,46 +256,60 @@ class RoomListHandler:
else:
more_to_come = False
print([(r.room_id, r.num_joined_members) for r in results])
if num_results > 0:
final_entry = results[-1]
initial_entry = results[0]
if forwards:
if has_batch_token:
if since_token is not None:
# If there was a token given then we assume that there
# must be previous results.
response["prev_batch"] = RoomListNextBatch(
last_joined_members=initial_entry["num_joined_members"],
last_room_id=initial_entry["room_id"],
last_joined_members=initial_entry.num_joined_members,
last_room_id=initial_entry.room_id,
direction_is_forward=False,
last_module_index=room_ids_to_module_index.get(
initial_entry.room_id
),
).to_token()
if more_to_come:
response["next_batch"] = RoomListNextBatch(
last_joined_members=final_entry["num_joined_members"],
last_room_id=final_entry["room_id"],
last_joined_members=final_entry.num_joined_members,
last_room_id=final_entry.room_id,
direction_is_forward=True,
last_module_index=room_ids_to_module_index.get(
final_entry.room_id
),
).to_token()
else:
if has_batch_token:
if since_token is not None:
response["next_batch"] = RoomListNextBatch(
last_joined_members=final_entry["num_joined_members"],
last_room_id=final_entry["room_id"],
last_joined_members=final_entry.num_joined_members,
last_room_id=final_entry.room_id,
direction_is_forward=True,
last_module_index=room_ids_to_module_index.get(
final_entry.room_id
),
).to_token()
if more_to_come:
response["prev_batch"] = RoomListNextBatch(
last_joined_members=initial_entry["num_joined_members"],
last_room_id=initial_entry["room_id"],
last_joined_members=initial_entry.num_joined_members,
last_room_id=initial_entry.room_id,
direction_is_forward=False,
last_module_index=room_ids_to_module_index.get(
initial_entry.room_id
),
).to_token()
response["chunk"] = results
response["chunk"] = [attr.asdict(r, filter=filter_none) for r in results]
response["total_room_count_estimate"] = await self.store.count_public_rooms(
network_tuple,
ignore_non_federatable=from_federation,
ignore_non_federatable=bool(from_remote_server_name),
search_filter=search_filter,
)
@@ -484,11 +551,13 @@ class RoomListNextBatch:
last_joined_members: int # The count to get rooms after/before
last_room_id: str # The room_id to get rooms after/before
direction_is_forward: bool # True if this is a next_batch, false if prev_batch
last_module_index: Optional[int] = None
KEY_DICT = {
"last_joined_members": "m",
"last_room_id": "r",
"direction_is_forward": "d",
"last_module_index": "i",
}
REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}
@@ -501,6 +570,7 @@ class RoomListNextBatch:
)
def to_token(self) -> str:
# print(self)
return encode_base64(
msgpack.dumps(
{self.KEY_DICT[key]: val for key, val in attr.asdict(self).items()}

View File

@@ -79,6 +79,9 @@ from synapse.module_api.callbacks.account_validity_callbacks import (
ON_LEGACY_SEND_MAIL_CALLBACK,
ON_USER_REGISTRATION_CALLBACK,
)
from synapse.module_api.callbacks.public_rooms_callbacks import (
FETCH_PUBLIC_ROOMS_CALLBACK,
)
from synapse.module_api.callbacks.spamchecker_callbacks import (
CHECK_EVENT_FOR_SPAM_CALLBACK,
CHECK_LOGIN_FOR_SPAM_CALLBACK,
@@ -170,6 +173,7 @@ __all__ = [
"DirectServeJsonResource",
"ModuleApi",
"PRESENCE_ALL_USERS",
"PublicRoomChunk",
"LoginResponse",
"JsonDict",
"JsonMapping",
@@ -472,6 +476,19 @@ class ModuleApi:
on_account_data_updated=on_account_data_updated,
)
def register_public_rooms_callbacks(
self,
*,
fetch_public_rooms: Optional[FETCH_PUBLIC_ROOMS_CALLBACK] = None,
) -> None:
"""Registers callback functions related to the public room directory.
Added in Synapse v1.80.0
"""
return self._callbacks.public_rooms.register_callbacks(
fetch_public_rooms=fetch_public_rooms,
)
def register_web_resource(self, path: str, resource: Resource) -> None:
"""Registers a web resource to be served at the given path.

View File

@@ -27,9 +27,12 @@ from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
ThirdPartyEventRulesModuleApiCallbacks,
)
from .public_rooms_callbacks import PublicRoomsModuleApiCallbacks
class ModuleApiCallbacks:
def __init__(self, hs: "HomeServer") -> None:
self.account_validity = AccountValidityModuleApiCallbacks()
self.spam_checker = SpamCheckerModuleApiCallbacks(hs)
self.third_party_event_rules = ThirdPartyEventRulesModuleApiCallbacks(hs)
self.public_rooms = PublicRoomsModuleApiCallbacks()

View File

@@ -0,0 +1,45 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Awaitable, Callable, List, Optional, Tuple
from synapse.types import PublicRoom, ThirdPartyInstanceID
logger = logging.getLogger(__name__)
# Types for callbacks to be registered via the module api
FETCH_PUBLIC_ROOMS_CALLBACK = Callable[
[
Optional[ThirdPartyInstanceID], # network_tuple
Optional[dict], # search_filter
Optional[int], # limit
Tuple[Optional[int], Optional[str]], # bounds
bool, # forwards
],
Awaitable[List[PublicRoom]],
]
class PublicRoomsModuleApiCallbacks:
def __init__(self) -> None:
self.fetch_public_rooms_callbacks: List[FETCH_PUBLIC_ROOMS_CALLBACK] = []
def register_callbacks(
self,
fetch_public_rooms: Optional[FETCH_PUBLIC_ROOMS_CALLBACK] = None,
) -> None:
if fetch_public_rooms is not None:
self.fetch_public_rooms_callbacks.append(fetch_public_rooms)

View File

@@ -476,8 +476,9 @@ class PublicRoomListRestServlet(RestServlet):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
server = parse_string(request, "server")
requester: Optional[Requester] = None
try:
await self.auth.get_user_by_req(request, allow_guest=True)
requester = await self.auth.get_user_by_req(request, allow_guest=True)
except InvalidClientCredentialsError as e:
# Option to allow servers to require auth when accessing
# /publicRooms via CS API. This is especially helpful in private
@@ -516,8 +517,15 @@ class PublicRoomListRestServlet(RestServlet):
server, limit=limit, since_token=since_token
)
else:
# If a user we know made this request, pass that information to the
# public rooms list handler.
if requester is None:
from_client_mxid = None
else:
from_client_mxid = requester.user.to_string()
data = await handler.get_local_public_room_list(
limit=limit, since_token=since_token
limit=limit, since_token=since_token, from_client_mxid=from_client_mxid
)
return 200, data

View File

@@ -38,6 +38,7 @@ from synapse.api.constants import (
Direction,
EventContentFields,
EventTypes,
HistoryVisibility,
JoinRules,
PublicRoomsFilterFields,
)
@@ -61,7 +62,13 @@ from synapse.storage.util.id_generators import (
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID
from synapse.types import (
JsonDict,
PublicRoom,
RetentionPolicy,
StrCollection,
ThirdPartyInstanceID,
)
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.stringutils import MXC_REGEX
@@ -365,21 +372,21 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
network_tuple: Optional[ThirdPartyInstanceID],
search_filter: Optional[dict],
limit: Optional[int],
bounds: Optional[Tuple[int, str]],
bounds: Tuple[Optional[int], Optional[str]],
forwards: bool,
ignore_non_federatable: bool = False,
) -> List[Dict[str, Any]]:
) -> List[PublicRoom]:
"""Gets the largest public rooms (where largest is in terms of joined
members, as tracked in the statistics table).
Args:
network_tuple
search_filter
limit: Maxmimum number of rows to return, unlimited otherwise.
bounds: An uppoer or lower bound to apply to result set if given,
limit: Maximum number of rows to return, unlimited otherwise.
bounds: An upper or lower bound to apply to result set if given,
consists of a joined member count and room_id (these are
excluded from result set).
forwards: true iff going forwards, going backwards otherwise
forwards: true if going forwards, going backwards otherwise
ignore_non_federatable: If true filters out non-federatable rooms.
Returns:
@@ -413,26 +420,18 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
# Work out the bounds if we're given them, these bounds look slightly
# odd, but are designed to help query planner use indices by pulling
# out a common bound.
if bounds:
last_joined_members, last_room_id = bounds
if forwards:
where_clauses.append(
"""
joined_members <= ? AND (
joined_members < ? OR room_id < ?
)
"""
)
else:
where_clauses.append(
"""
joined_members >= ? AND (
joined_members > ? OR room_id > ?
)
"""
)
last_joined_members, last_room_id = bounds
if last_joined_members is not None:
comp = "<" if forwards else ">"
query_args += [last_joined_members, last_joined_members, last_room_id]
clause = f"joined_members {comp} ?"
query_args += [last_joined_members]
if last_room_id is not None:
clause += f" OR (joined_members = ? AND room_id {comp} ?)"
query_args += [last_joined_members, last_room_id]
where_clauses.append(clause)
if ignore_non_federatable:
where_clauses.append("is_federatable")
@@ -518,7 +517,25 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
ret_val = await self.db_pool.runInteraction(
"get_largest_public_rooms", _get_largest_public_rooms_txn
)
return ret_val
def build_room_entry(room: JsonDict) -> PublicRoom:
entry = PublicRoom(
room_id=room["room_id"],
name=room["name"],
topic=room["topic"],
canonical_alias=room["canonical_alias"],
num_joined_members=room["joined_members"],
avatar_url=room["avatar"],
world_readable=room["history_visibility"]
== HistoryVisibility.WORLD_READABLE,
guest_can_join=room["guest_access"] == "can_join",
join_rule=room["join_rules"],
room_type=room["room_type"],
)
return entry
return [build_room_entry(r) for r in ret_val]
@cached(max_entries=10000)
async def is_room_blocked(self, room_id: str) -> Optional[bool]:

View File

@@ -1045,6 +1045,20 @@ class UserInfo:
locked: bool
@attr.s(auto_attribs=True, frozen=True, slots=True)
class PublicRoom:
room_id: str
num_joined_members: int
world_readable: bool
guest_can_join: bool
name: Optional[str] = None
topic: Optional[str] = None
canonical_alias: Optional[str] = None
avatar_url: Optional[str] = None
join_rule: Optional[str] = None
room_type: Optional[str] = None
class UserProfile(TypedDict):
user_id: str
display_name: Optional[str]

View File

@@ -206,3 +206,7 @@ class ExceptionBundle(Exception):
parts.append(str(e))
super().__init__("\n - ".join(parts))
self.exceptions = exceptions
def filter_none(attr: attr.Attribute, value: Any) -> bool:
return value is not None

View File

@@ -0,0 +1,261 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple
from twisted.test.proto_helpers import MemoryReactor
from synapse.rest import admin, login, room
from synapse.server import HomeServer
from synapse.types import PublicRoom, ThirdPartyInstanceID
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
class FetchPublicRoomsTestCase(HomeserverTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
room.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["allow_public_rooms_without_auth"] = True
self.hs = self.setup_test_homeserver(config=config)
self.url = "/_matrix/client/r0/publicRooms"
return self.hs
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self._store = homeserver.get_datastores().main
self._module_api = homeserver.get_module_api()
async def module1_cb(
network_tuple: Optional[ThirdPartyInstanceID],
search_filter: Optional[dict],
limit: Optional[int],
bounds: Tuple[Optional[int], Optional[str]],
forwards: bool,
) -> List[PublicRoom]:
room1 = PublicRoom(
room_id="!one_members:module1",
num_joined_members=1,
world_readable=True,
guest_can_join=False,
)
room3 = PublicRoom(
room_id="!three_members:module1",
num_joined_members=3,
world_readable=True,
guest_can_join=False,
)
room3_2 = PublicRoom(
room_id="!three_members_2:module1",
num_joined_members=3,
world_readable=True,
guest_can_join=False,
)
(last_joined_members, last_room_id) = bounds
if forwards:
result = [room3_2, room3, room1]
else:
result = [room1, room3, room3_2]
if last_joined_members is not None:
if last_joined_members == 1:
if forwards:
if last_room_id == room1.room_id:
result = []
else:
result = [room1]
else:
if last_room_id == room1.room_id:
result = [room3, room3_2]
else:
result = [room1, room3, room3_2]
elif last_joined_members == 2:
if forwards:
result = [room1]
else:
result = [room3, room3_2]
elif last_joined_members == 3:
if forwards:
if last_room_id == room3.room_id:
result = [room1]
elif last_room_id == room3_2.room_id:
result = [room3, room1]
else:
if last_room_id == room3.room_id:
result = [room3_2]
elif last_room_id == room3_2.room_id:
result = []
else:
result = [room3, room3_2]
if limit is not None:
result = result[:limit]
return result
async def module2_cb(
network_tuple: Optional[ThirdPartyInstanceID],
search_filter: Optional[dict],
limit: Optional[int],
bounds: Tuple[Optional[int], Optional[str]],
forwards: bool,
) -> List[PublicRoom]:
room3 = PublicRoom(
room_id="!three_members:module2",
num_joined_members=3,
world_readable=True,
guest_can_join=False,
)
(last_joined_members, last_room_id) = bounds
result = [room3]
if last_joined_members is not None:
if forwards:
if last_joined_members < 3:
result = []
elif last_joined_members == 3 and last_room_id == room3.room_id:
result = []
else:
if last_joined_members > 3:
result = []
elif last_joined_members == 3 and last_room_id == room3.room_id:
result = []
return result
self._module_api.register_public_rooms_callbacks(fetch_public_rooms=module1_cb)
self._module_api.register_public_rooms_callbacks(fetch_public_rooms=module2_cb)
user = self.register_user("alice", "pass")
token = self.login(user, "pass")
user2 = self.register_user("alice2", "pass")
token2 = self.login(user2, "pass")
user3 = self.register_user("alice3", "pass")
token3 = self.login(user3, "pass")
# Create a room with 2 people
room_id = self.helper.create_room_as(
user,
is_public=True,
extra_content={"visibility": "public"},
tok=token,
)
self.helper.join(room_id, user2, tok=token2)
# Create a room with 3 people
room_id = self.helper.create_room_as(
user,
is_public=True,
extra_content={"visibility": "public"},
tok=token,
)
self.helper.join(room_id, user2, tok=token2)
self.helper.join(room_id, user3, tok=token3)
def test_no_limit(self) -> None:
channel = self.make_request("GET", self.url)
chunk = channel.json_body["chunk"]
self.assertEquals(len(chunk), 6)
for i in range(4):
self.assertEquals(chunk[i]["num_joined_members"], 3)
self.assertEquals(chunk[4]["num_joined_members"], 2)
self.assertEquals(chunk[5]["num_joined_members"], 1)
def test_pagination_limit_1(self) -> None:
returned_three_members_rooms = set()
next_batch = None
for _i in range(4):
since_query_str = f"&since={next_batch}" if next_batch else ""
channel = self.make_request("GET", f"{self.url}?limit=1{since_query_str}")
chunk = channel.json_body["chunk"]
self.assertEquals(chunk[0]["num_joined_members"], 3)
self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms)
returned_three_members_rooms.add(chunk[0]["room_id"])
next_batch = channel.json_body["next_batch"]
channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
chunk = channel.json_body["chunk"]
self.assertEquals(chunk[0]["num_joined_members"], 2)
next_batch = channel.json_body["next_batch"]
channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
chunk = channel.json_body["chunk"]
self.assertEquals(chunk[0]["num_joined_members"], 1)
prev_batch = channel.json_body["prev_batch"]
self.assertNotIn("next_batch", channel.json_body)
channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}")
chunk = channel.json_body["chunk"]
self.assertEquals(chunk[0]["num_joined_members"], 2)
returned_three_members_rooms = set()
for _i in range(4):
prev_batch = channel.json_body["prev_batch"]
channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}")
chunk = channel.json_body["chunk"]
self.assertEquals(chunk[0]["num_joined_members"], 3)
self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms)
returned_three_members_rooms.add(chunk[0]["room_id"])
self.assertNotIn("prev_batch", channel.json_body)
def test_pagination_limit_2(self) -> None:
returned_three_members_rooms = set()
next_batch = None
for _i in range(2):
since_query_str = f"&since={next_batch}" if next_batch else ""
channel = self.make_request("GET", f"{self.url}?limit=2{since_query_str}")
chunk = channel.json_body["chunk"]
self.assertEquals(chunk[0]["num_joined_members"], 3)
self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms)
returned_three_members_rooms.add(chunk[0]["room_id"])
self.assertTrue(chunk[1]["room_id"] not in returned_three_members_rooms)
returned_three_members_rooms.add(chunk[1]["room_id"])
next_batch = channel.json_body["next_batch"]
channel = self.make_request("GET", f"{self.url}?limit=2&since={next_batch}")
chunk = channel.json_body["chunk"]
self.assertEquals(chunk[0]["num_joined_members"], 2)
self.assertEquals(chunk[1]["num_joined_members"], 1)
self.assertNotIn("next_batch", channel.json_body)
returned_three_members_rooms = set()
for _i in range(2):
prev_batch = channel.json_body["prev_batch"]
channel = self.make_request("GET", f"{self.url}?limit=2&since={prev_batch}")
chunk = channel.json_body["chunk"]
self.assertEquals(chunk[0]["num_joined_members"], 3)
self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms)
returned_three_members_rooms.add(chunk[0]["room_id"])
self.assertTrue(chunk[1]["room_id"] not in returned_three_members_rooms)
returned_three_members_rooms.add(chunk[1]["room_id"])
self.assertNotIn("prev_batch", channel.json_body)

View File

@@ -0,0 +1,148 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.test.proto_helpers import MemoryReactor
from synapse.rest import admin, login, room
from synapse.server import HomeServer
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
class PublicRoomsTestCase(HomeserverTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
room.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["allow_public_rooms_without_auth"] = True
self.hs = self.setup_test_homeserver(config=config)
self.url = "/_matrix/client/r0/publicRooms"
return self.hs
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self._store = homeserver.get_datastores().main
user = self.register_user("alice", "pass")
token = self.login(user, "pass")
user2 = self.register_user("alice2", "pass")
token2 = self.login(user2, "pass")
user3 = self.register_user("alice3", "pass")
token3 = self.login(user3, "pass")
# Create 10 rooms
for _ in range(3):
self.helper.create_room_as(
user,
is_public=True,
extra_content={"visibility": "public"},
tok=token,
)
for _ in range(3):
room_id = self.helper.create_room_as(
user,
is_public=True,
extra_content={"visibility": "public"},
tok=token,
)
self.helper.join(room_id, user2, tok=token2)
for _ in range(4):
room_id = self.helper.create_room_as(
user,
is_public=True,
extra_content={"visibility": "public"},
tok=token,
)
self.helper.join(room_id, user2, tok=token2)
self.helper.join(room_id, user3, tok=token3)
def test_no_limit(self) -> None:
channel = self.make_request("GET", self.url)
chunk = channel.json_body["chunk"]
self.assertEquals(len(chunk), 10)
def test_pagination_limit_1(self) -> None:
returned_rooms = set()
channel = None
for i in range(10):
next_batch = None if i == 0 else channel.json_body["next_batch"]
since_query_str = f"&since={next_batch}" if next_batch else ""
channel = self.make_request("GET", f"{self.url}?limit=1{since_query_str}")
chunk = channel.json_body["chunk"]
self.assertEquals(len(chunk), 1)
print(chunk[0]["room_id"])
self.assertTrue(chunk[0]["room_id"] not in returned_rooms)
returned_rooms.add(chunk[0]["room_id"])
self.assertNotIn("next_batch", channel.json_body)
returned_rooms = set()
returned_rooms.add(chunk[0]["room_id"])
for i in range(9):
print(i)
prev_batch = channel.json_body["prev_batch"]
channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}")
chunk = channel.json_body["chunk"]
self.assertEquals(len(chunk), 1)
print(chunk[0]["room_id"])
self.assertTrue(chunk[0]["room_id"] not in returned_rooms)
returned_rooms.add(chunk[0]["room_id"])
def test_pagination_limit_2(self) -> None:
returned_rooms = set()
channel = None
for i in range(5):
next_batch = None if i == 0 else channel.json_body["next_batch"]
since_query_str = f"&since={next_batch}" if next_batch else ""
channel = self.make_request("GET", f"{self.url}?limit=2{since_query_str}")
chunk = channel.json_body["chunk"]
self.assertEquals(len(chunk), 2)
print(chunk[0]["room_id"])
self.assertTrue(chunk[0]["room_id"] not in returned_rooms)
returned_rooms.add(chunk[0]["room_id"])
print(chunk[1]["room_id"])
self.assertTrue(chunk[1]["room_id"] not in returned_rooms)
returned_rooms.add(chunk[1]["room_id"])
self.assertNotIn("next_batch", channel.json_body)
returned_rooms = set()
returned_rooms.add(chunk[0]["room_id"])
returned_rooms.add(chunk[1]["room_id"])
for i in range(4):
print(i)
prev_batch = channel.json_body["prev_batch"]
channel = self.make_request("GET", f"{self.url}?limit=2&since={prev_batch}")
chunk = channel.json_body["chunk"]
self.assertEquals(len(chunk), 2)
print(chunk[0]["room_id"])
self.assertTrue(chunk[0]["room_id"] not in returned_rooms)
returned_rooms.add(chunk[0]["room_id"])
print(chunk[1]["room_id"])
self.assertTrue(chunk[1]["room_id"] not in returned_rooms)
returned_rooms.add(chunk[1]["room_id"])