Compare commits

...

7 Commits

Author SHA1 Message Date
Erik Johnston
d48caa1800 Newsfile 2024-07-03 15:24:17 +01:00
Erik Johnston
e3d5988aa6 Allow enabling sliding sync per-user 2024-07-03 15:23:37 +01:00
Erik Johnston
0721903ee9 Newsfile 2024-07-03 14:20:37 +01:00
Erik Johnston
ff16dd5af5 Use per-user feature flags for MSC3881 2024-07-03 14:20:37 +01:00
Erik Johnston
8ce087dd00 Optionally check for auth in /versions 2024-07-03 14:20:37 +01:00
Erik Johnston
e104003b4f Remove unused feature 2024-07-03 14:20:37 +01:00
Erik Johnston
e61b6a1cb3 Add check for per-user feature flag 2024-07-03 14:20:37 +01:00
12 changed files with 271 additions and 44 deletions

1
changelog.d/17392.misc Normal file
View File

@@ -0,0 +1 @@
Finish up work to allow per-user feature flags.

1
changelog.d/17393.misc Normal file
View File

@@ -0,0 +1 @@
Allow enabling sliding sync per-user.

View File

@@ -18,7 +18,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
from typing import Optional, Tuple
from typing import TYPE_CHECKING, Optional, Tuple
from typing_extensions import Protocol
@@ -28,6 +28,9 @@ from synapse.appservice import ApplicationService
from synapse.http.site import SynapseRequest
from synapse.types import Requester
if TYPE_CHECKING:
from synapse.rest.admin.experimental_features import ExperimentalFeature
# guests always get this device id.
GUEST_DEVICE_ID = "guest_device"
@@ -87,6 +90,19 @@ class Auth(Protocol):
AuthError if access is denied for the user in the access token
"""
async def get_user_by_req_experimental_feature(
self,
request: SynapseRequest,
feature: "ExperimentalFeature",
allow_guest: bool = False,
allow_expired: bool = False,
allow_locked: bool = False,
) -> Requester:
"""Like `get_user_by_req`, except also checks if the user has access to
the experimental feature. If they don't returns a 404 unrecognized
request.
"""
async def validate_appservice_can_control_user_id(
self, app_service: ApplicationService, user_id: str
) -> None:

View File

@@ -28,6 +28,7 @@ from synapse.api.errors import (
Codes,
InvalidClientTokenError,
MissingClientTokenError,
UnrecognizedRequestError,
)
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import active_span, force_tracing, start_active_span
@@ -38,8 +39,10 @@ from . import GUEST_DEVICE_ID
from .base import BaseAuth
if TYPE_CHECKING:
from synapse.rest.admin.experimental_features import ExperimentalFeature
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -106,6 +109,32 @@ class InternalAuth(BaseAuth):
parent_span.set_tag("appservice_id", requester.app_service.id)
return requester
async def get_user_by_req_experimental_feature(
self,
request: SynapseRequest,
feature: "ExperimentalFeature",
allow_guest: bool = False,
allow_expired: bool = False,
allow_locked: bool = False,
) -> Requester:
try:
requester = await self.get_user_by_req(
request,
allow_guest=allow_guest,
allow_expired=allow_expired,
allow_locked=allow_locked,
)
if await self.store.is_feature_enabled(requester.user.to_string(), feature):
return requester
raise UnrecognizedRequestError(code=404)
except (AuthError, InvalidClientTokenError):
if feature.is_globally_enabled(self.hs.config):
# If its globally enabled then return the auth error
raise
raise UnrecognizedRequestError(code=404)
@cancellable
async def _wrapped_get_user_by_req(
self,

View File

@@ -40,6 +40,7 @@ from synapse.api.errors import (
OAuthInsufficientScopeError,
StoreError,
SynapseError,
UnrecognizedRequestError,
)
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
@@ -48,6 +49,7 @@ from synapse.util import json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
if TYPE_CHECKING:
from synapse.rest.admin.experimental_features import ExperimentalFeature
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -245,6 +247,32 @@ class MSC3861DelegatedAuth(BaseAuth):
return requester
async def get_user_by_req_experimental_feature(
self,
request: SynapseRequest,
feature: "ExperimentalFeature",
allow_guest: bool = False,
allow_expired: bool = False,
allow_locked: bool = False,
) -> Requester:
try:
requester = await self.get_user_by_req(
request,
allow_guest=allow_guest,
allow_expired=allow_expired,
allow_locked=allow_locked,
)
if await self.store.is_feature_enabled(requester.user.to_string(), feature):
return requester
raise UnrecognizedRequestError(code=404)
except (AuthError, InvalidClientTokenError):
if feature.is_globally_enabled(self.hs.config):
# If its globally enabled then return the auth error
raise
raise UnrecognizedRequestError(code=404)
async def get_user_by_access_token(
self,
token: str,

View File

@@ -31,7 +31,9 @@ from synapse.rest.admin import admin_patterns, assert_requester_is_admin
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
from typing_extensions import assert_never
from synapse.server import HomeServer, HomeServerConfig
class ExperimentalFeature(str, Enum):
@@ -39,8 +41,16 @@ class ExperimentalFeature(str, Enum):
Currently supported per-user features
"""
MSC3026 = "msc3026"
MSC3881 = "msc3881"
MSC3575 = "msc3575"
def is_globally_enabled(self, config: "HomeServerConfig") -> bool:
if self is ExperimentalFeature.MSC3881:
return config.experimental.msc3881_enabled
if self is ExperimentalFeature.MSC3575:
return config.experimental.msc3575_enabled
assert_never(self)
class ExperimentalFeaturesRestServlet(RestServlet):

View File

@@ -32,6 +32,7 @@ from synapse.http.servlet import (
)
from synapse.http.site import SynapseRequest
from synapse.push import PusherConfigException
from synapse.rest.admin.experimental_features import ExperimentalFeature
from synapse.rest.client._base import client_patterns
from synapse.rest.synapse.client.unsubscribe import UnsubscribeResource
from synapse.types import JsonDict
@@ -49,20 +50,22 @@ class PushersRestServlet(RestServlet):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled
self._store = hs.get_datastores().main
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = requester.user
user_id = requester.user.to_string()
pushers = await self.hs.get_datastores().main.get_pushers_by_user_id(
user.to_string()
msc3881_enabled = await self._store.is_feature_enabled(
user_id, ExperimentalFeature.MSC3881
)
pushers = await self.hs.get_datastores().main.get_pushers_by_user_id(user_id)
pusher_dicts = [p.as_dict() for p in pushers]
for pusher in pusher_dicts:
if self._msc3881_enabled:
if msc3881_enabled:
pusher["org.matrix.msc3881.enabled"] = pusher["enabled"]
pusher["org.matrix.msc3881.device_id"] = pusher["device_id"]
del pusher["enabled"]
@@ -80,11 +83,15 @@ class PushersSetRestServlet(RestServlet):
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
self.pusher_pool = self.hs.get_pusherpool()
self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled
self._store = hs.get_datastores().main
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = requester.user
user_id = requester.user.to_string()
msc3881_enabled = await self._store.is_feature_enabled(
user_id, ExperimentalFeature.MSC3881
)
content = parse_json_object_from_request(request)
@@ -95,7 +102,7 @@ class PushersSetRestServlet(RestServlet):
and content["kind"] is None
):
await self.pusher_pool.remove_pusher(
content["app_id"], content["pushkey"], user_id=user.to_string()
content["app_id"], content["pushkey"], user_id=user_id
)
return 200, {}
@@ -120,19 +127,19 @@ class PushersSetRestServlet(RestServlet):
append = content["append"]
enabled = True
if self._msc3881_enabled and "org.matrix.msc3881.enabled" in content:
if msc3881_enabled and "org.matrix.msc3881.enabled" in content:
enabled = content["org.matrix.msc3881.enabled"]
if not append:
await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
app_id=content["app_id"],
pushkey=content["pushkey"],
not_user_id=user.to_string(),
not_user_id=user_id,
)
try:
await self.pusher_pool.add_or_update_pusher(
user_id=user.to_string(),
user_id=user_id,
kind=content["kind"],
app_id=content["app_id"],
app_display_name=content["app_display_name"],

View File

@@ -53,6 +53,7 @@ from synapse.http.servlet import (
)
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import trace_with_opname
from synapse.rest.admin.experimental_features import ExperimentalFeature
from synapse.types import JsonDict, Requester, StreamToken
from synapse.types.rest.client import SlidingSyncBody
from synapse.util import json_decoder
@@ -673,7 +674,9 @@ class SlidingSyncE2eeRestServlet(RestServlet):
)
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = await self.auth.get_user_by_req_experimental_feature(
request, allow_guest=True, feature=ExperimentalFeature.MSC3575
)
user = requester.user
device_id = requester.device_id
@@ -873,7 +876,10 @@ class SlidingSyncRestServlet(RestServlet):
self.event_serializer = hs.get_event_client_serializer()
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = await self.auth.get_user_by_req_experimental_feature(
request, allow_guest=True, feature=ExperimentalFeature.MSC3575
)
user = requester.user
device_id = requester.device_id
@@ -1051,6 +1057,5 @@ class SlidingSyncRestServlet(RestServlet):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server)
if hs.config.experimental.msc3575_enabled:
SlidingSyncRestServlet(hs).register(http_server)
SlidingSyncE2eeRestServlet(hs).register(http_server)
SlidingSyncRestServlet(hs).register(http_server)
SlidingSyncE2eeRestServlet(hs).register(http_server)

View File

@@ -25,11 +25,11 @@ import logging
import re
from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.api.constants import RoomCreationPreset
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.rest.admin.experimental_features import ExperimentalFeature
from synapse.types import JsonDict
if TYPE_CHECKING:
@@ -45,6 +45,8 @@ class VersionsRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.config = hs.config
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
# Calculate these once since they shouldn't change after start-up.
self.e2ee_forced_public = (
@@ -60,7 +62,17 @@ class VersionsRestServlet(RestServlet):
in self.config.room.encryption_enabled_by_default_for_room_presets
)
def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
msc3881_enabled = self.config.experimental.msc3881_enabled
if self.auth.has_access_token(request):
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
msc3881_enabled = await self.store.is_feature_enabled(
user_id, ExperimentalFeature.MSC3881
)
return (
200,
{
@@ -124,7 +136,7 @@ class VersionsRestServlet(RestServlet):
# TODO: this is no longer needed once unstable MSC3882 does not need to be supported:
"org.matrix.msc3882": self.config.auth.login_via_existing_enabled,
# Adds support for remotely enabling/disabling pushers, as per MSC3881
"org.matrix.msc3881": self.config.experimental.msc3881_enabled,
"org.matrix.msc3881": msc3881_enabled,
# Adds support for filtering /messages by event relation.
"org.matrix.msc3874": self.config.experimental.msc3874_enabled,
# Adds support for simple HTTP rendezvous as per MSC3886

View File

@@ -21,7 +21,11 @@
from typing import TYPE_CHECKING, Dict, FrozenSet, List, Tuple, cast
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main import CacheInvalidationWorkerStore
from synapse.util.caches.descriptors import cached
@@ -73,12 +77,54 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore):
features:
pairs of features and True/False for whether the feature should be enabled
"""
for feature, enabled in features.items():
await self.db_pool.simple_upsert(
table="per_user_experimental_features",
keyvalues={"feature": feature, "user_id": user},
values={"enabled": enabled},
insertion_values={"user_id": user, "feature": feature},
)
await self.invalidate_cache_and_stream("list_enabled_features", (user,))
def set_features_for_user_txn(txn: LoggingTransaction) -> None:
for feature, enabled in features.items():
self.db_pool.simple_upsert_txn(
txn,
table="per_user_experimental_features",
keyvalues={"feature": feature, "user_id": user},
values={"enabled": enabled},
insertion_values={"user_id": user, "feature": feature},
)
self._invalidate_cache_and_stream(
txn, self.is_feature_enabled, (user, feature)
)
self._invalidate_cache_and_stream(txn, self.list_enabled_features, (user,))
return await self.db_pool.runInteraction(
"set_features_for_user", set_features_for_user_txn
)
@cached()
async def is_feature_enabled(
self, user_id: str, feature: "ExperimentalFeature"
) -> bool:
"""
Checks to see if a given feature is enabled for the user
Args:
user_id: the user to be queried on
feature: the feature in question
Returns:
True if the feature is enabled, False if it is not or if the feature was
not found.
"""
if feature.is_globally_enabled(self.hs.config):
return True
# if it's not enabled globally, check if it is enabled per-user
res = await self.db_pool.simple_select_one_onecol(
table="per_user_experimental_features",
keyvalues={"user_id": user_id, "feature": feature},
retcol="enabled",
allow_none=True,
desc="get_feature_enabled",
)
# None and false are treated the same
db_enabled = bool(res)
return db_enabled

View File

@@ -26,7 +26,8 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.logging.context import make_deferred_yieldable
from synapse.push import PusherConfig, PusherConfigException
from synapse.rest.client import login, push_rule, pusher, receipts, room
from synapse.rest.admin.experimental_features import ExperimentalFeature
from synapse.rest.client import login, push_rule, pusher, receipts, room, versions
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
@@ -42,6 +43,7 @@ class HTTPPusherTests(HomeserverTestCase):
receipts.register_servlets,
push_rule.register_servlets,
pusher.register_servlets,
versions.register_servlets,
]
user_id = True
hijack_auth = False
@@ -969,6 +971,84 @@ class HTTPPusherTests(HomeserverTestCase):
lookup_result.device_id,
)
def test_device_id_feature_flag(self) -> None:
"""Tests that a pusher created with a given device ID shows that device ID in
GET /pushers requests when feature is enabled for the user
"""
user_id = self.register_user("user", "pass")
access_token = self.login("user", "pass")
# We create the pusher with an HTTP request rather than with
# _make_user_with_pusher so that we can test the device ID is correctly set when
# creating a pusher via an API call.
self.make_request(
method="POST",
path="/pushers/set",
content={
"kind": "http",
"app_id": "m.http",
"app_display_name": "HTTP Push Notifications",
"device_display_name": "pushy push",
"pushkey": "a@example.com",
"lang": "en",
"data": {"url": "http://example.com/_matrix/push/v1/notify"},
},
access_token=access_token,
)
# Look up the user info for the access token so we can compare the device ID.
store = self.hs.get_datastores().main
lookup_result = self.get_success(store.get_user_by_access_token(access_token))
assert lookup_result is not None
# Check field is not there before we enable the feature flag
channel = self.make_request("GET", "/pushers", access_token=access_token)
self.assertEqual(channel.code, 200)
self.assertEqual(len(channel.json_body["pushers"]), 1)
self.assertNotIn(
"org.matrix.msc3881.device_id", channel.json_body["pushers"][0]
)
self.get_success(
store.set_features_for_user(user_id, {ExperimentalFeature.MSC3881: True})
)
# Get the user's devices and check it has the correct device ID.
channel = self.make_request("GET", "/pushers", access_token=access_token)
self.assertEqual(channel.code, 200)
self.assertEqual(len(channel.json_body["pushers"]), 1)
self.assertEqual(
channel.json_body["pushers"][0]["org.matrix.msc3881.device_id"],
lookup_result.device_id,
)
def test_msc3881_client_versions_flag(self) -> None:
"""Tests that MSC3881 only appears in /versions if user has it enabled."""
user_id = self.register_user("user", "pass")
access_token = self.login("user", "pass")
# Check feature is disabled in /versions
channel = self.make_request(
"GET", "/_matrix/client/versions", access_token=access_token
)
self.assertEqual(channel.code, 200)
self.assertFalse(channel.json_body["unstable_features"]["org.matrix.msc3881"])
# Enable feature for user
self.get_success(
self.hs.get_datastores().main.set_features_for_user(
user_id, {ExperimentalFeature.MSC3881: True}
)
)
# Check feature is now enabled in /versions for user
channel = self.make_request(
"GET", "/_matrix/client/versions", access_token=access_token
)
self.assertEqual(channel.code, 200)
self.assertTrue(channel.json_body["unstable_features"]["org.matrix.msc3881"])
@override_config({"push": {"jitter_delay": "10s"}})
def test_jitter(self) -> None:
"""Tests that enabling jitter actually delays sending push."""

View File

@@ -384,7 +384,7 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase):
"PUT",
url,
content={
"features": {"msc3026": True, "msc3881": True},
"features": {"msc3881": True},
},
access_token=self.admin_user_tok,
)
@@ -399,10 +399,6 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
self.assertEqual(channel.code, 200)
self.assertEqual(
True,
channel.json_body["features"]["msc3026"],
)
self.assertEqual(
True,
channel.json_body["features"]["msc3881"],
@@ -413,7 +409,7 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
url,
content={"features": {"msc3026": False}},
content={"features": {"msc3881": False}},
access_token=self.admin_user_tok,
)
self.assertEqual(channel.code, 200)
@@ -429,10 +425,6 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(
False,
channel.json_body["features"]["msc3026"],
)
self.assertEqual(
True,
channel.json_body["features"]["msc3881"],
)
@@ -441,7 +433,7 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
url,
content={"features": {"msc3026": False}},
content={"features": {"msc3881": False}},
access_token=self.admin_user_tok,
)
self.assertEqual(channel.code, 200)