mirror of
https://github.com/element-hq/synapse.git
synced 2025-12-05 01:10:13 +00:00
Compare commits
2 Commits
patch-1
...
mv/test-ac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
941ab4b483 | ||
|
|
13113d8378 |
@@ -29,7 +29,7 @@ class SynapsePlugin(Plugin):
|
||||
self, fullname: str
|
||||
) -> Optional[Callable[[MethodSigContext], CallableType]]:
|
||||
if fullname.startswith(
|
||||
"synapse.util.caches.descriptors._CachedFunction.__call__"
|
||||
"synapse.util.caches.descriptors.CachedFunction.__call__"
|
||||
) or fullname.startswith(
|
||||
"synapse.util.caches.descriptors._LruCachedFunction.__call__"
|
||||
):
|
||||
@@ -38,7 +38,7 @@ class SynapsePlugin(Plugin):
|
||||
|
||||
|
||||
def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
|
||||
"""Fixes the `_CachedFunction.__call__` signature to be correct.
|
||||
"""Fixes the `CachedFunction.__call__` signature to be correct.
|
||||
|
||||
It already has *almost* the correct signature, except:
|
||||
|
||||
|
||||
@@ -125,7 +125,7 @@ from synapse.types import (
|
||||
)
|
||||
from synapse.util import Clock
|
||||
from synapse.util.async_helpers import maybe_awaitable
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.caches.descriptors import CachedFunction, cached
|
||||
from synapse.util.frozenutils import freeze
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -836,6 +836,20 @@ class ModuleApi:
|
||||
self._store.db_pool.runInteraction(desc, func, *args, **kwargs) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
async def invalidate_cache(
|
||||
self, cached_func: CachedFunction, keys: Tuple[Any, ...]
|
||||
) -> None:
|
||||
cached_func.invalidate(keys)
|
||||
await self._store.send_invalidation_to_replication(
|
||||
cached_func.__qualname__,
|
||||
keys,
|
||||
)
|
||||
|
||||
def register_cached_function(self, cached_func: CachedFunction) -> None:
|
||||
self._store.register_external_cached_function(
|
||||
cached_func.__qualname__, cached_func
|
||||
)
|
||||
|
||||
def complete_sso_login(
|
||||
self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
|
||||
) -> None:
|
||||
|
||||
@@ -95,7 +95,7 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||
|
||||
def _attempt_to_invalidate_cache(
|
||||
self, cache_name: str, key: Optional[Collection[Any]]
|
||||
) -> None:
|
||||
) -> bool:
|
||||
"""Attempts to invalidate the cache of the given name, ignoring if the
|
||||
cache doesn't exist. Mainly used for invalidating caches on workers,
|
||||
where they may not have the cache.
|
||||
@@ -115,7 +115,7 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||
except AttributeError:
|
||||
# We probably haven't pulled in the cache in this worker,
|
||||
# which is fine.
|
||||
return
|
||||
return False
|
||||
|
||||
if key is None:
|
||||
cache.invalidate_all()
|
||||
@@ -125,6 +125,8 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||
invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
|
||||
invalidate_method(tuple(key))
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
|
||||
"""
|
||||
|
||||
@@ -33,7 +33,7 @@ from synapse.storage.database import (
|
||||
)
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
from synapse.util.caches.descriptors import _CachedFunction
|
||||
from synapse.util.caches.descriptors import CachedFunction
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -91,6 +91,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
else:
|
||||
self._cache_id_gen = None
|
||||
|
||||
self.external_cached_functions = {}
|
||||
|
||||
def register_external_cached_function(self, cache_name, func):
|
||||
self.external_cached_functions[cache_name] = func
|
||||
|
||||
async def get_all_updated_caches(
|
||||
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
||||
@@ -178,7 +183,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
members_changed = set(row.keys[1:])
|
||||
self._invalidate_state_caches(room_id, members_changed)
|
||||
else:
|
||||
self._attempt_to_invalidate_cache(row.cache_func, row.keys)
|
||||
res = self._attempt_to_invalidate_cache(row.cache_func, row.keys)
|
||||
if not res:
|
||||
external_func = self.external_cached_functions[row.cache_func]
|
||||
if external_func:
|
||||
external_func.invalidate(row.keys)
|
||||
|
||||
super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
@@ -269,9 +278,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
return
|
||||
|
||||
cache_func.invalidate(keys)
|
||||
await self.db_pool.runInteraction(
|
||||
"invalidate_cache_and_stream",
|
||||
self._send_invalidation_to_replication,
|
||||
await self.send_invalidation_to_replication(
|
||||
cache_func.__name__,
|
||||
keys,
|
||||
)
|
||||
@@ -279,7 +286,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
def _invalidate_cache_and_stream(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
cache_func: _CachedFunction,
|
||||
cache_func: CachedFunction,
|
||||
keys: Tuple[Any, ...],
|
||||
) -> None:
|
||||
"""Invalidates the cache and adds it to the cache stream so slaves
|
||||
@@ -293,7 +300,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
|
||||
|
||||
def _invalidate_all_cache_and_stream(
|
||||
self, txn: LoggingTransaction, cache_func: _CachedFunction
|
||||
self, txn: LoggingTransaction, cache_func: CachedFunction
|
||||
) -> None:
|
||||
"""Invalidates the entire cache and adds it to the cache stream so slaves
|
||||
will know to invalidate their caches.
|
||||
@@ -334,6 +341,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
txn, CURRENT_STATE_CACHE_NAME, [room_id]
|
||||
)
|
||||
|
||||
async def send_invalidation_to_replication(self, cache_name, keys):
|
||||
await self.db_pool.runInteraction(
|
||||
"send_invalidation_to_replication",
|
||||
self._send_invalidation_to_replication,
|
||||
cache_name,
|
||||
keys,
|
||||
)
|
||||
|
||||
def _send_invalidation_to_replication(
|
||||
self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
|
||||
) -> None:
|
||||
|
||||
@@ -52,7 +52,7 @@ CacheKey = Union[Tuple, Any]
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
class _CachedFunction(Generic[F]):
|
||||
class CachedFunction(Generic[F]):
|
||||
invalidate: Any = None
|
||||
invalidate_all: Any = None
|
||||
prefill: Any = None
|
||||
@@ -239,7 +239,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
|
||||
|
||||
return ret2
|
||||
|
||||
wrapped = cast(_CachedFunction, _wrapped)
|
||||
wrapped = cast(CachedFunction, _wrapped)
|
||||
wrapped.cache = cache
|
||||
obj.__dict__[self.orig.__name__] = wrapped
|
||||
|
||||
@@ -358,7 +358,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
||||
|
||||
return make_deferred_yieldable(ret)
|
||||
|
||||
wrapped = cast(_CachedFunction, _wrapped)
|
||||
wrapped = cast(CachedFunction, _wrapped)
|
||||
|
||||
if self.num_args == 1:
|
||||
assert not self.tree
|
||||
@@ -577,7 +577,7 @@ def cached(
|
||||
cache_context: bool = False,
|
||||
iterable: bool = False,
|
||||
prune_unread_entries: bool = True,
|
||||
) -> Callable[[F], _CachedFunction[F]]:
|
||||
) -> Callable[[F], CachedFunction[F]]:
|
||||
func = lambda orig: DeferredCacheDescriptor(
|
||||
orig,
|
||||
max_entries=max_entries,
|
||||
@@ -589,12 +589,12 @@ def cached(
|
||||
prune_unread_entries=prune_unread_entries,
|
||||
)
|
||||
|
||||
return cast(Callable[[F], _CachedFunction[F]], func)
|
||||
return cast(Callable[[F], CachedFunction[F]], func)
|
||||
|
||||
|
||||
def cachedList(
|
||||
*, cached_method_name: str, list_name: str, num_args: Optional[int] = None
|
||||
) -> Callable[[F], _CachedFunction[F]]:
|
||||
) -> Callable[[F], CachedFunction[F]]:
|
||||
"""Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
|
||||
|
||||
Used to do batch lookups for an already created cache. One of the arguments
|
||||
@@ -630,7 +630,7 @@ def cachedList(
|
||||
num_args=num_args,
|
||||
)
|
||||
|
||||
return cast(Callable[[F], _CachedFunction[F]], func)
|
||||
return cast(Callable[[F], CachedFunction[F]], func)
|
||||
|
||||
|
||||
def _get_cache_key_builder(
|
||||
|
||||
240
tests/replication/test_account_validity.py
Normal file
240
tests/replication/test_account_validity.py
Normal file
@@ -0,0 +1,240 @@
|
||||
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Optional, cast
|
||||
|
||||
from twisted.internet.defer import ensureDeferred
|
||||
|
||||
import synapse
|
||||
from synapse.module_api import DatabasePool, LoggingTransaction, ModuleApi, cached
|
||||
from synapse.server import HomeServer
|
||||
|
||||
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||
from tests.server import ThreadedMemoryReactorClock, make_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MockAccountValidityStore:
|
||||
def __init__(
|
||||
self,
|
||||
api: ModuleApi,
|
||||
):
|
||||
self._api = api
|
||||
|
||||
api.register_cached_function(self.is_user_expired)
|
||||
|
||||
async def create_db(self):
|
||||
def create_table_txn(txn: LoggingTransaction):
|
||||
txn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS mock_account_validity(
|
||||
user_id TEXT PRIMARY KEY,
|
||||
expired BOOLEAN NOT NULL
|
||||
)
|
||||
""",
|
||||
(),
|
||||
)
|
||||
|
||||
await self._api.run_db_interaction(
|
||||
"account_validity_create_table",
|
||||
create_table_txn,
|
||||
)
|
||||
|
||||
@cached()
|
||||
async def is_user_expired(self, user_id: str) -> Optional[bool]:
|
||||
def get_expiration_for_user_txn(txn: LoggingTransaction):
|
||||
return DatabasePool.simple_select_one_onecol_txn(
|
||||
txn=txn,
|
||||
table="mock_account_validity",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcol="expired",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
return await self._api.run_db_interaction(
|
||||
"get_expiration_for_user",
|
||||
get_expiration_for_user_txn,
|
||||
)
|
||||
|
||||
async def on_user_registration(self, user_id: str) -> None:
|
||||
def add_valid_user_txn(txn: LoggingTransaction):
|
||||
txn.execute(
|
||||
"INSERT INTO mock_account_validity (user_id, expired) VALUES (?, ?)",
|
||||
(user_id, False),
|
||||
)
|
||||
|
||||
await self._api.run_db_interaction(
|
||||
"account_validity_add_valid_user",
|
||||
add_valid_user_txn,
|
||||
)
|
||||
|
||||
async def set_expired(self, user_id: str, expired: bool = True) -> None:
|
||||
def set_expired_user_txn(txn: LoggingTransaction):
|
||||
txn.execute(
|
||||
"UPDATE mock_account_validity SET expired = ? WHERE user_id = ?",
|
||||
(
|
||||
expired,
|
||||
user_id,
|
||||
),
|
||||
)
|
||||
|
||||
await self._api.run_db_interaction(
|
||||
"account_validity_set_expired_user",
|
||||
set_expired_user_txn,
|
||||
)
|
||||
|
||||
await self._api.invalidate_cache(self.is_user_expired, (user_id,))
|
||||
|
||||
|
||||
class MockAccountValidity:
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
api: ModuleApi,
|
||||
):
|
||||
self._api = api
|
||||
|
||||
self._store = MockAccountValidityStore(api)
|
||||
|
||||
ensureDeferred(self._store.create_db())
|
||||
cast(ThreadedMemoryReactorClock, api._hs.get_reactor()).pump([0.0])
|
||||
|
||||
self._api.register_account_validity_callbacks(
|
||||
is_user_expired=self.is_user_expired,
|
||||
on_user_registration=self.on_user_registration,
|
||||
)
|
||||
|
||||
async def is_user_expired(self, user_id: str) -> Optional[bool]:
|
||||
return await self._store.is_user_expired(user_id)
|
||||
|
||||
async def on_user_registration(self, user_id: str) -> None:
|
||||
await self._store.on_user_registration(user_id)
|
||||
|
||||
|
||||
class WorkerAccountValidityTestCase(BaseMultiWorkerStreamTestCase):
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets,
|
||||
synapse.rest.client.account.register_servlets,
|
||||
synapse.rest.client.login.register_servlets,
|
||||
synapse.rest.client.register.register_servlets,
|
||||
]
|
||||
|
||||
def default_config(self):
|
||||
config = super().default_config()
|
||||
|
||||
config["modules"] = [
|
||||
{
|
||||
"module": __name__ + ".MockAccountValidity",
|
||||
}
|
||||
]
|
||||
|
||||
return config
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
hs = super().make_homeserver(reactor, clock)
|
||||
module_api = hs.get_module_api()
|
||||
for module, config in hs.config.modules.loaded_modules:
|
||||
self.module = module(config=config, api=module_api)
|
||||
logger.info("Loaded module %s", self.module)
|
||||
return hs
|
||||
|
||||
def make_worker_hs(
|
||||
self, worker_app: str, extra_config: Optional[dict] = None, **kwargs
|
||||
) -> HomeServer:
|
||||
hs = super().make_worker_hs(worker_app, extra_config=extra_config)
|
||||
module_api = hs.get_module_api()
|
||||
for module, config in hs.config.modules.loaded_modules:
|
||||
# Do not store the module in self here since we want to expire the user
|
||||
# from the main worker and see if it get properly replicated to the other one.
|
||||
module(config=config, api=module_api)
|
||||
logger.info("Loaded module %s", self.module)
|
||||
return hs
|
||||
|
||||
def _create_and_check_user(self):
|
||||
self.register_user("user", "pass")
|
||||
user_id = "@user:test"
|
||||
token = self.login("user", "pass")
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/v3/account/whoami",
|
||||
access_token=token,
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.json_body["user_id"], user_id)
|
||||
|
||||
return user_id, token
|
||||
|
||||
def test_account_validity(self):
|
||||
user_id, token = self._create_and_check_user()
|
||||
|
||||
self.get_success_or_raise(self.module._store.set_expired(user_id))
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/v3/account/whoami",
|
||||
access_token=token,
|
||||
)
|
||||
self.assertEqual(channel.code, 403)
|
||||
|
||||
self.get_success_or_raise(self.module._store.set_expired(user_id, False))
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/v3/account/whoami",
|
||||
access_token=token,
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
def test_account_validity_with_worker_and_cache(self):
|
||||
worker_hs = self.make_worker_hs("synapse.app.generic_worker")
|
||||
worker_site = self._hs_to_site[worker_hs]
|
||||
|
||||
user_id, token = self._create_and_check_user()
|
||||
|
||||
# check than the user is valid on the other worker too
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
worker_site,
|
||||
"GET",
|
||||
"/_matrix/client/v3/account/whoami",
|
||||
access_token=token,
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
# Expires user on the main worker, and check its state on the other worker
|
||||
self.get_success_or_raise(self.module._store.set_expired(user_id))
|
||||
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
worker_site,
|
||||
"GET",
|
||||
"/_matrix/client/v3/account/whoami",
|
||||
access_token=token,
|
||||
)
|
||||
self.assertEqual(channel.code, 403)
|
||||
|
||||
# Un-expires user on the main worker, and check its state on the other worker
|
||||
self.get_success_or_raise(self.module._store.set_expired(user_id, False))
|
||||
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
worker_site,
|
||||
"GET",
|
||||
"/_matrix/client/v3/account/whoami",
|
||||
access_token=token,
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
88
tests/replication/test_module_cache_invalidation.py
Normal file
88
tests/replication/test_module_cache_invalidation.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging, time
|
||||
|
||||
import synapse
|
||||
from synapse.module_api import ModuleApi, cached
|
||||
from synapse.server import HomeServer
|
||||
|
||||
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
FIRST_VALUE = "one"
|
||||
SECOND_VALUE = "two"
|
||||
|
||||
KEY = "mykey"
|
||||
|
||||
class TestCache:
|
||||
current_value = FIRST_VALUE
|
||||
|
||||
@cached()
|
||||
async def cached_function(self, user_id: str) -> str:
|
||||
print(self.current_value)
|
||||
return self.current_value
|
||||
|
||||
class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
|
||||
|
||||
def test_module_cache_full_invalidation(self):
|
||||
main_cache = TestCache()
|
||||
self.hs.get_module_api().register_cached_function(main_cache.cached_function)
|
||||
|
||||
worker_hs = self.make_worker_hs("synapse.app.generic_worker")
|
||||
|
||||
worker_cache = TestCache()
|
||||
worker_hs.get_module_api().register_cached_function(worker_cache.cached_function)
|
||||
|
||||
self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
|
||||
self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)))
|
||||
|
||||
main_cache.current_value = SECOND_VALUE
|
||||
worker_cache.current_value = SECOND_VALUE
|
||||
# No invalidation yet, should return the cached value on both the main process and the worker
|
||||
self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
|
||||
self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)))
|
||||
|
||||
self.reactor.advance(1)
|
||||
|
||||
# Full invalidation on the main process, should be replicated on the worker that
|
||||
# should returned the updated value too
|
||||
self.get_success(
|
||||
self.hs.get_module_api().invalidate_cache(main_cache.cached_function, (KEY,))
|
||||
)
|
||||
|
||||
self.reactor.advance(1)
|
||||
|
||||
self.assertEqual(SECOND_VALUE, self.get_success(main_cache.cached_function(KEY)))
|
||||
self.assertEqual(SECOND_VALUE, self.get_success(worker_cache.cached_function(KEY)))
|
||||
|
||||
# def test_module_cache_local_invalidation_only(self):
|
||||
# main_cache = TestCache()
|
||||
# worker_cache = TestCache()
|
||||
|
||||
# worker_hs = self.make_worker_hs("synapse.app.generic_worker")
|
||||
|
||||
# self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
|
||||
# self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)))
|
||||
|
||||
# main_cache.current_value = SECOND_VALUE
|
||||
# worker_cache.current_value = SECOND_VALUE
|
||||
# # No local invalidation yet, should return the cached value on both the main process and the worker
|
||||
# self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
|
||||
# self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)))
|
||||
|
||||
# # local invalidation on the main process, worker should still return the cached value
|
||||
# main_cache.cached_function.invalidate((KEY,))
|
||||
# self.assertEqual(SECOND_VALUE, self.get_success(main_cache.cached_function(KEY)))
|
||||
# self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)))
|
||||
Reference in New Issue
Block a user