Compare commits

...

14 Commits

Author SHA1 Message Date
Sean Quah
07a5623059 De-localpart {Filtering,FilteringWorkerStore}.add_user_filter()
Signed-off-by: Sean Quah <seanq@matrix.org>
2023-04-15 02:52:42 +01:00
Sean Quah
f98141ceb2 De-localpart {Filtering,FilteringWorkerStore}.get_user_filter()
Signed-off-by: Sean Quah <seanq@matrix.org>
2023-04-15 02:52:42 +01:00
Sean Quah
06f9ababc4 Fix stray write to profiles to set full_user_id
Signed-off-by: Sean Quah <seanq@matrix.org>
2023-04-15 02:52:42 +01:00
Sean Quah
1a4f41b3de De-localpart ProfileWorkerStore.set_profile_avatar_url()
Signed-off-by: Sean Quah <seanq@matrix.org>
2023-04-15 02:52:42 +01:00
Sean Quah
1dcbff40d6 De-localpart ProfileWorkerStore.set_profile_displayname()
Signed-off-by: Sean Quah <seanq@matrix.org>
2023-04-15 02:52:42 +01:00
Sean Quah
76d6379727 De-localpart ProfileWorkerStore.create_profile()
Signed-off-by: Sean Quah <seanq@matrix.org>
2023-04-15 02:52:42 +01:00
Sean Quah
96bb319d14 De-localpart ProfileWorkerStore.get_profile_avatar_url()
Signed-off-by: Sean Quah <seanq@matrix.org>
2023-04-15 02:52:42 +01:00
Sean Quah
e6c582095f De-localpart ProfileWorkerStore.get_profile_displayname()
Signed-off-by: Sean Quah <seanq@matrix.org>
2023-04-15 02:52:42 +01:00
Sean Quah
b375e2abd9 De-localpart ProfileWorkerStore.get_profileinfo()
Signed-off-by: Sean Quah <seanq@matrix.org>
2023-04-15 02:52:42 +01:00
Sean Quah
0a734d0cf2 Add background update to populate user_filters.full_user_id
Signed-off-by: Sean Quah <seanq@matrix.org>
2023-04-15 02:52:42 +01:00
Sean Quah
cc90467096 Add background update to populate profiles.full_user_id
Signed-off-by: Sean Quah <seanq@matrix.org>
2023-04-15 02:52:42 +01:00
Sean Quah
8182e8ad14 Add user_filters.full_user_id column
Signed-off-by: Sean Quah <seanq@matrix.org>
2023-04-15 02:52:42 +01:00
Sean Quah
03ee93ee1a Add profiles.full_user_id column
Signed-off-by: Sean Quah <seanq@matrix.org>
2023-04-15 02:52:42 +01:00
Sean Quah
8810abab33 Bump the schema version
Signed-off-by: Sean Quah <seanq@matrix.org>
2023-04-15 02:12:35 +01:00
27 changed files with 445 additions and 143 deletions

View File

@@ -165,16 +165,14 @@ class Filtering:
self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {})
async def get_user_filter(
self, user_localpart: str, filter_id: Union[int, str]
self, user_id: str, filter_id: Union[int, str]
) -> "FilterCollection":
result = await self.store.get_user_filter(user_localpart, filter_id)
result = await self.store.get_user_filter(user_id, filter_id)
return FilterCollection(self._hs, result)
def add_user_filter(
self, user_localpart: str, user_filter: JsonDict
) -> Awaitable[int]:
def add_user_filter(self, user_id: str, user_filter: JsonDict) -> Awaitable[int]:
self.check_valid_filter(user_filter)
return self.store.add_user_filter(user_localpart, user_filter)
return self.store.add_user_filter(user_id, user_filter)
# TODO(paul): surely we should probably add a delete_user_filter or
# replace_user_filter at some point? There's no REST API specified for

View File

@@ -19,7 +19,6 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.types import UserID
from synapse.util import stringutils
from synapse.util.async_helpers import delay_cancellation
@@ -163,9 +162,7 @@ class AccountValidityHandler:
return
try:
user_display_name = await self.store.get_profile_displayname(
UserID.from_string(user_id).localpart
)
user_display_name = await self.store.get_profile_displayname(user_id)
if user_display_name is None:
user_display_name = user_id
except StoreError:

View File

@@ -89,7 +89,7 @@ class AdminHandler:
}
# Add additional user metadata
profile = await self._store.get_profileinfo(user.localpart)
profile = await self._store.get_profileinfo(user.to_string())
threepids = await self._store.user_get_threepids(user.to_string())
external_ids = [
({"auth_provider": auth_provider, "external_id": external_id})

View File

@@ -1756,9 +1756,7 @@ class AuthHandler:
respond_with_html(request, 403, self._sso_account_deactivated_template)
return
user_profile_data = await self.store.get_profileinfo(
UserID.from_string(registered_user_id).localpart
)
user_profile_data = await self.store.get_profileinfo(registered_user_id)
# Store any extra attributes which will be passed in the login response.
# Note that this is per-user so it may overwrite a previous value, this

View File

@@ -282,8 +282,6 @@ class DeactivateAccountHandler:
Args:
user_id: ID of user to be re-activated
"""
user = UserID.from_string(user_id)
# Ensure the user is not marked as erased.
await self.store.mark_user_not_erased(user_id)
@@ -297,5 +295,5 @@ class DeactivateAccountHandler:
# Add the user to the directory, if necessary. Note that
# this must be done after the user is re-activated, because
# deactivated users are excluded from the user directory.
profile = await self.store.get_profileinfo(user.localpart)
profile = await self.store.get_profileinfo(user_id)
await self.user_directory_handler.handle_local_profile_change(user_id, profile)

View File

@@ -67,7 +67,7 @@ class ProfileHandler:
target_user = UserID.from_string(user_id)
if self.hs.is_mine(target_user):
profileinfo = await self.store.get_profileinfo(target_user.localpart)
profileinfo = await self.store.get_profileinfo(user_id)
if profileinfo.display_name is None:
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
@@ -100,7 +100,7 @@ class ProfileHandler:
if self.hs.is_mine(target_user):
try:
displayname = await self.store.get_profile_displayname(
target_user.localpart
target_user.to_string()
)
except StoreError as e:
if e.code == 404:
@@ -147,7 +147,7 @@ class ProfileHandler:
raise AuthError(400, "Cannot set another user's displayname")
if not by_admin and not self.hs.config.registration.enable_set_displayname:
profile = await self.store.get_profileinfo(target_user.localpart)
profile = await self.store.get_profileinfo(target_user.to_string())
if profile.display_name:
raise SynapseError(
400,
@@ -179,10 +179,10 @@ class ProfileHandler:
)
await self.store.set_profile_displayname(
target_user.localpart, displayname_to_set
target_user.to_string(), displayname_to_set
)
profile = await self.store.get_profileinfo(target_user.localpart)
profile = await self.store.get_profileinfo(target_user.to_string())
await self.user_directory_handler.handle_local_profile_change(
target_user.to_string(), profile
)
@@ -197,7 +197,7 @@ class ProfileHandler:
if self.hs.is_mine(target_user):
try:
avatar_url = await self.store.get_profile_avatar_url(
target_user.localpart
target_user.to_string()
)
except StoreError as e:
if e.code == 404:
@@ -243,7 +243,7 @@ class ProfileHandler:
raise AuthError(400, "Cannot set another user's avatar_url")
if not by_admin and not self.hs.config.registration.enable_set_avatar_url:
profile = await self.store.get_profileinfo(target_user.localpart)
profile = await self.store.get_profileinfo(target_user.to_string())
if profile.avatar_url:
raise SynapseError(
400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
@@ -273,10 +273,10 @@ class ProfileHandler:
)
await self.store.set_profile_avatar_url(
target_user.localpart, avatar_url_to_set
target_user.to_string(), avatar_url_to_set
)
profile = await self.store.get_profileinfo(target_user.localpart)
profile = await self.store.get_profileinfo(target_user.to_string())
await self.user_directory_handler.handle_local_profile_change(
target_user.to_string(), profile
)
@@ -364,7 +364,8 @@ class ProfileHandler:
Codes.FORBIDDEN,
)
user = UserID.from_string(args["user_id"])
user_id = args["user_id"]
user = UserID.from_string(user_id)
if not self.hs.is_mine(user):
raise SynapseError(400, "User is not hosted on this homeserver")
@@ -374,12 +375,12 @@ class ProfileHandler:
try:
if just_field is None or just_field == "displayname":
response["displayname"] = await self.store.get_profile_displayname(
user.localpart
user_id
)
if just_field is None or just_field == "avatar_url":
response["avatar_url"] = await self.store.get_profile_avatar_url(
user.localpart
user_id
)
except StoreError as e:
if e.code == 404:

View File

@@ -314,7 +314,7 @@ class RegistrationHandler:
approved=approved,
)
profile = await self.store.get_profileinfo(localpart)
profile = await self.store.get_profileinfo(user_id)
await self.user_directory_handler.handle_local_profile_change(
user_id, profile
)

View File

@@ -648,7 +648,8 @@ class ModuleApi:
Returns:
The profile information (i.e. display name and avatar URL).
"""
return await self._store.get_profileinfo(localpart)
user = UserID(localpart, self._hs.hostname)
return await self._store.get_profileinfo(user.to_string())
async def get_threepids_for_user(self, user_id: str) -> List[Dict[str, str]]:
"""Look up the threepids (email addresses and phone numbers) associated with the

View File

@@ -37,7 +37,7 @@ from synapse.push.push_types import (
TemplateVars,
)
from synapse.storage.databases.main.event_push_actions import EmailPushAction
from synapse.types import StateMap, UserID
from synapse.types import StateMap
from synapse.types.state import StateFilter
from synapse.util.async_helpers import concurrently_execute
from synapse.visibility import filter_events_for_client
@@ -246,9 +246,7 @@ class Mailer:
state_by_room = {}
try:
user_display_name = await self.store.get_profile_displayname(
UserID.from_string(user_id).localpart
)
user_display_name = await self.store.get_profile_displayname(user_id)
if user_display_name is None:
user_display_name = user_id
except StoreError:

View File

@@ -58,7 +58,7 @@ class GetFilterRestServlet(RestServlet):
try:
filter_collection = await self.filtering.get_user_filter(
user_localpart=target_user.localpart, filter_id=filter_id_int
user_id=user_id, filter_id=filter_id_int
)
except StoreError as e:
if e.code != 404:
@@ -94,7 +94,7 @@ class CreateFilterRestServlet(RestServlet):
set_timeline_upper_limit(content, self.hs.config.server.filter_timeline_limit)
filter_id = await self.filtering.add_user_filter(
user_localpart=target_user.localpart, user_filter=content
user_id=user_id, user_filter=content
)
return 200, {"filter_id": str(filter_id)}

View File

@@ -178,7 +178,7 @@ class SyncRestServlet(RestServlet):
else:
try:
filter_collection = await self.filtering.get_user_filter(
user.localpart, filter_id
user.to_string(), filter_id
)
except StoreError as err:
if err.code != 404:

View File

@@ -43,7 +43,7 @@ from .event_federation import EventFederationStore
from .event_push_actions import EventPushActionsStore
from .events_bg_updates import EventsBackgroundUpdatesStore
from .events_forward_extremities import EventForwardExtremitiesStore
from .filtering import FilteringWorkerStore
from .filtering import FilteringStore
from .keys import KeyStore
from .lock import LockStore
from .media_repository import MediaRepositoryStore
@@ -99,7 +99,7 @@ class DataStore(
EventFederationStore,
MediaRepositoryStore,
RejectionsStore,
FilteringWorkerStore,
FilteringStore,
PusherStore,
PushRuleStore,
ApplicationServiceTransactionStore,

View File

@@ -13,22 +13,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
from canonicaljson import encode_canonical_json
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.types import JsonDict, UserID
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
from synapse.server import HomeServer
class FilteringWorkerStore(SQLBaseStore):
@cached(num_args=2)
async def get_user_filter(
self, user_localpart: str, filter_id: Union[int, str]
self, user_id: str, filter_id: Union[int, str]
) -> JsonDict:
user_localpart = UserID.from_string(user_id).localpart
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
# with a coherent error message rather than 500 M_UNKNOWN.
try:
@@ -36,17 +44,31 @@ class FilteringWorkerStore(SQLBaseStore):
except ValueError:
raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
def_json = await self.db_pool.simple_select_one_onecol(
table="user_filters",
keyvalues={"user_id": user_localpart, "filter_id": filter_id},
retcol="filter_json",
allow_none=False,
desc="get_user_filter",
)
try:
def_json = await self.db_pool.simple_select_one_onecol(
table="user_filters",
keyvalues={"full_user_id": user_id, "filter_id": filter_id},
retcol="filter_json",
allow_none=False,
desc="get_user_filter",
)
except StoreError as e:
if e.code == 404:
# Fall back to the `user_id` column.
def_json = await self.db_pool.simple_select_one_onecol(
table="user_filters",
keyvalues={"user_id": user_localpart, "filter_id": filter_id},
retcol="filter_json",
allow_none=False,
desc="get_user_filter",
)
else:
raise
return db_to_json(def_json)
async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int:
async def add_user_filter(self, user_id: str, user_filter: JsonDict) -> int:
user_localpart = UserID.from_string(user_id).localpart
def_json = encode_canonical_json(user_filter)
# Need an atomic transaction to SELECT the maximal ID so far then
@@ -70,10 +92,10 @@ class FilteringWorkerStore(SQLBaseStore):
filter_id = max_id + 1
sql = (
"INSERT INTO user_filters (user_id, filter_id, filter_json)"
"VALUES(?, ?, ?)"
"INSERT INTO user_filters (full_user_id, user_id, filter_id, filter_json)"
"VALUES(?, ?, ?, ?)"
)
txn.execute(sql, (user_localpart, filter_id, bytearray(def_json)))
txn.execute(sql, (user_id, user_localpart, filter_id, bytearray(def_json)))
return filter_id
@@ -97,3 +119,67 @@ class FilteringWorkerStore(SQLBaseStore):
if attempts >= 5:
raise StoreError(500, "Couldn't generate a filter ID.")
class FilteringBackgroundUpdateStore(FilteringWorkerStore):
POPULATE_USER_FILTERS_FULL_USER_ID = "populate_user_filters_full_user_id"
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
self.POPULATE_USER_FILTERS_FULL_USER_ID,
self._populate_user_filters_full_user_id,
)
async def _populate_user_filters_full_user_id(
self, progress: JsonDict, batch_size: int
) -> int:
"""Populates the `user_filters.full_user_id` column.
In a future Synapse version, this column will be renamed to `user_id`, replacing
the existing `user_id` column.
Note that completion of this background update does not imply that there are no
longer any `NULL` values in `full_user_id`. Until the old `user_id` column has
been removed, Synapse may be rolled back to a previous version which does not
populate `full_user_id` after the background update has finished.
"""
def _populate_user_filters_full_user_id_txn(
txn: LoggingTransaction,
) -> bool:
sql = """
UPDATE user_filters
SET full_user_id = '@' || user_id || ':' || ?
WHERE user_id IN (
SELECT user_id
FROM user_filters
WHERE full_user_id IS NULL
LIMIT ?
)
"""
txn.execute(sql, (self.hs.hostname, batch_size))
return txn.rowcount == 0
finished = await self.db_pool.runInteraction(
"_populate_user_filters_full_user_id_txn",
_populate_user_filters_full_user_id_txn,
)
if finished:
await self.db_pool.updates._end_background_update(
self.POPULATE_USER_FILTERS_FULL_USER_ID
)
return batch_size
class FilteringStore(FilteringBackgroundUpdateStore):
pass

View File

@@ -11,22 +11,41 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import TYPE_CHECKING, Optional
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.roommember import ProfileInfo
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
class ProfileWorkerStore(SQLBaseStore):
async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
async def get_profileinfo(self, user_id: str) -> ProfileInfo:
try:
profile = await self.db_pool.simple_select_one(
table="profiles",
keyvalues={"user_id": user_localpart},
keyvalues={"full_user_id": user_id},
retcols=("displayname", "avatar_url"),
allow_none=True,
desc="get_profileinfo",
)
if profile is None:
# Fall back to the `user_id` column.
user_localpart = UserID.from_string(user_id).localpart
profile = await self.db_pool.simple_select_one(
table="profiles",
keyvalues={"user_id": user_localpart},
retcols=("displayname", "avatar_url"),
desc="get_profileinfo",
)
except StoreError as e:
if e.code == 404:
# no match
@@ -38,47 +57,138 @@ class ProfileWorkerStore(SQLBaseStore):
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
)
async def get_profile_displayname(self, user_localpart: str) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="displayname",
desc="get_profile_displayname",
)
async def get_profile_displayname(self, user_id: str) -> Optional[str]:
try:
return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"full_user_id": user_id},
retcol="displayname",
desc="get_profile_displayname",
)
except StoreError as e:
if e.code == 404:
# Fall back to the `user_id` column.
user_localpart = UserID.from_string(user_id).localpart
return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="displayname",
desc="get_profile_displayname",
)
else:
raise
async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="avatar_url",
desc="get_profile_avatar_url",
)
async def get_profile_avatar_url(self, user_id: str) -> Optional[str]:
try:
return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"full_user_id": user_id},
retcol="avatar_url",
desc="get_profile_avatar_url",
)
except StoreError as e:
if e.code == 404:
# Fall back to the `user_id` column.
user_localpart = UserID.from_string(user_id).localpart
return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="avatar_url",
desc="get_profile_avatar_url",
)
else:
raise
async def create_profile(self, user_localpart: str) -> None:
async def create_profile(self, user_id: str) -> None:
user_localpart = UserID.from_string(user_id).localpart
await self.db_pool.simple_insert(
table="profiles", values={"user_id": user_localpart}, desc="create_profile"
table="profiles",
values={"user_id": user_localpart, "full_user_id": user_id},
desc="create_profile",
)
async def set_profile_displayname(
self, user_localpart: str, new_displayname: Optional[str]
self, user_id: str, new_displayname: Optional[str]
) -> None:
user_localpart = UserID.from_string(user_id).localpart
await self.db_pool.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
values={"displayname": new_displayname},
values={"full_user_id": user_id, "displayname": new_displayname},
desc="set_profile_displayname",
)
async def set_profile_avatar_url(
self, user_localpart: str, new_avatar_url: Optional[str]
self, user_id: str, new_avatar_url: Optional[str]
) -> None:
user_localpart = UserID.from_string(user_id).localpart
await self.db_pool.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
values={"avatar_url": new_avatar_url},
values={"full_user_id": user_id, "avatar_url": new_avatar_url},
desc="set_profile_avatar_url",
)
class ProfileStore(ProfileWorkerStore):
class ProfileBackgroundUpdateStore(ProfileWorkerStore):
POPULATE_PROFILES_FULL_USER_ID = "populate_profiles_full_user_id"
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
self.POPULATE_PROFILES_FULL_USER_ID,
self._populate_profiles_full_user_id,
)
async def _populate_profiles_full_user_id(
self, progress: JsonDict, batch_size: int
) -> int:
"""Populates the `profiles.full_user_id` column.
In a future Synapse version, this column will be renamed to `user_id`, replacing
the existing `user_id` column.
Note that completion of this background update does not imply that there are no
longer any `NULL` values in `full_user_id`. Until the old `user_id` column has
been removed, Synapse may be rolled back to a previous version which does not
populate `full_user_id` after the background update has finished.
"""
def _populate_profiles_full_user_id_txn(
txn: LoggingTransaction,
) -> bool:
sql = """
UPDATE profiles
SET full_user_id = '@' || user_id || ':' || ?
WHERE user_id IN (
SELECT user_id
FROM profiles
WHERE full_user_id IS NULL
LIMIT ?
)
"""
txn.execute(sql, (self.hs.hostname, batch_size))
return txn.rowcount == 0
finished = await self.db_pool.runInteraction(
"_populate_profiles_full_user_id_txn",
_populate_profiles_full_user_id_txn,
)
if finished:
await self.db_pool.updates._end_background_update(
self.POPULATE_PROFILES_FULL_USER_ID
)
return batch_size
class ProfileStore(ProfileBackgroundUpdateStore):
pass

View File

@@ -2414,8 +2414,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
# *obviously* the 'profiles' table uses localpart for user_id
# while everything else uses the full mxid.
txn.execute(
"INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
(user_id_obj.localpart, create_profile_with_displayname),
"INSERT INTO profiles(full_user_id, user_id, displayname) VALUES (?,?,?)",
(user_id, user_id_obj.localpart, create_profile_with_displayname),
)
if self.hs.config.stats.stats_enabled:

View File

@@ -383,7 +383,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
for user_id in users_to_work_on:
if await self.should_include_local_user_in_dir(user_id):
profile = await self.get_profileinfo(get_localpart_from_id(user_id)) # type: ignore[attr-defined]
profile = await self.get_profileinfo(user_id) # type: ignore[attr-defined]
await self.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url
)

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
SCHEMA_VERSION = 74 # remember to update the list below when updating
SCHEMA_VERSION = 75 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the

View File

@@ -0,0 +1,35 @@
/* Copyright 2023 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE profiles ADD COLUMN full_user_id TEXT;
-- Add a new constraint on the new column, mirroring the `profiles_user_id_key`
-- constraint.
ALTER TABLE ONLY profiles
ADD CONSTRAINT profiles_full_user_id_key UNIQUE (full_user_id);
-- Also ensure that new rows have the `full_user_id` field populated.
-- TODO: Move this to phase two of the migration. In a multi-worker deployment, it will
-- prevent un-updated workers from doing any UPDATEs. That is, it effectively
-- prevents rollback of Synapse to an earlier version when the column has not been
-- fully populated.
ALTER TABLE ONLY profiles
ADD CONSTRAINT profiles_full_user_id_not_null CHECK (full_user_id IS NOT NULL) NOT VALID;
-- `profiles` can contain on the order of 10s/100s of millions of rows. We use
-- `NOT VALID` so that we do not lock the table to check existing rows. New rows will
-- still be checked.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(7501, 'populate_profiles_full_user_id', '{}');

View File

@@ -0,0 +1,43 @@
/* Copyright 2023 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE profiles ADD COLUMN full_user_id TEXT;
-- Add a new constraint on the new column, mirroring the `user_id` constraint.
--
-- SQLite doesn't support modifying constraints on an existing table, so it must be
-- recreated.
CREATE TABLE profiles_new(
full_user_id TEXT,
user_id TEXT NOT NULL,
displayname TEXT,
avatar_url TEXT,
UNIQUE (full_user_id),
UNIQUE (user_id)
);
-- Copy the data.
INSERT INTO profiles_new (full_user_id, user_id, displayname, avatar_url)
SELECT NULL, user_id, displayname, avatar_url
FROM profiles;
-- Drop the old table.
DROP TABLE profiles;
-- Rename the table.
ALTER TABLE profiles_new RENAME TO profiles;
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(7501, 'populate_profiles_full_user_id', '{}');

View File

@@ -0,0 +1,26 @@
/* Copyright 2023 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE user_filters ADD COLUMN full_user_id TEXT;
-- Add a unique index on the new column, mirroring the `user_filters_unique` unique
-- index.
CREATE UNIQUE INDEX full_user_filters_unique ON user_filters (full_user_id, filter_id);
-- NB: This will lock the table for writes while the index is being built.
-- There are around 4,000,000 user_filters on matrix.org so we expect this to take
-- a couple of seconds at most.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(7502, 'populate_user_filters_full_user_id', '{}');

View File

@@ -33,7 +33,8 @@ from synapse.util.frozenutils import freeze
from tests import unittest
from tests.events.test_utils import MockEvent
user_localpart = "test_user"
user_id = "@test_user:test"
user2_id = "@test_user2:test"
class FilteringTestCase(unittest.HomeserverTestCase):
@@ -437,7 +438,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
user_filter_json = {"presence": {"senders": ["@foo:bar"]}}
filter_id = self.get_success(
self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
user_id=user_id, user_filter=user_filter_json
)
)
presence_states = [
@@ -453,9 +454,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
]
user_filter = self.get_success(
self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
)
results = self.get_success(user_filter.filter_presence(presence_states))
@@ -467,7 +466,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
filter_id = self.get_success(
self.datastore.add_user_filter(
user_localpart=user_localpart + "2", user_filter=user_filter_json
user_id=user2_id, user_filter=user_filter_json
)
)
presence_states = [
@@ -483,9 +482,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
]
user_filter = self.get_success(
self.filtering.get_user_filter(
user_localpart=user_localpart + "2", filter_id=filter_id
)
self.filtering.get_user_filter(user_id=user2_id, filter_id=filter_id)
)
results = self.get_success(user_filter.filter_presence(presence_states))
@@ -495,16 +492,14 @@ class FilteringTestCase(unittest.HomeserverTestCase):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = self.get_success(
self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
user_id=user_id, user_filter=user_filter_json
)
)
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
events = [event]
user_filter = self.get_success(
self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
)
results = self.get_success(user_filter.filter_room_state(events=events))
@@ -514,7 +509,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = self.get_success(
self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
user_id=user_id, user_filter=user_filter_json
)
)
event = MockEvent(
@@ -523,9 +518,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
events = [event]
user_filter = self.get_success(
self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
)
results = self.get_success(user_filter.filter_room_state(events))
@@ -598,7 +591,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
filter_id = self.get_success(
self.filtering.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
user_id=user_id, user_filter=user_filter_json
)
)
@@ -607,9 +600,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
user_filter_json,
(
self.get_success(
self.datastore.get_user_filter(
user_localpart=user_localpart, filter_id=0
)
self.datastore.get_user_filter(user_id=user_id, filter_id=0)
)
),
)
@@ -619,14 +610,12 @@ class FilteringTestCase(unittest.HomeserverTestCase):
filter_id = self.get_success(
self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
user_id=user_id, user_filter=user_filter_json
)
)
filter = self.get_success(
self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
)
self.assertEqual(filter.get_filter_json(), user_filter_json)

View File

@@ -67,7 +67,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
def test_get_my_name(self) -> None:
self.get_success(
self.store.set_profile_displayname(self.frank.localpart, "Frank")
self.store.set_profile_displayname(self.frank.to_string(), "Frank")
)
displayname = self.get_success(self.handler.get_displayname(self.frank))
@@ -84,7 +84,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual(
(
self.get_success(
self.store.get_profile_displayname(self.frank.localpart)
self.store.get_profile_displayname(self.frank.to_string())
)
),
"Frank Jr.",
@@ -100,7 +100,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual(
(
self.get_success(
self.store.get_profile_displayname(self.frank.localpart)
self.store.get_profile_displayname(self.frank.to_string())
)
),
"Frank",
@@ -114,7 +114,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertIsNone(
self.get_success(self.store.get_profile_displayname(self.frank.localpart))
self.get_success(self.store.get_profile_displayname(self.frank.to_string()))
)
def test_set_my_name_if_disabled(self) -> None:
@@ -122,13 +122,13 @@ class ProfileTestCase(unittest.HomeserverTestCase):
# Setting displayname for the first time is allowed
self.get_success(
self.store.set_profile_displayname(self.frank.localpart, "Frank")
self.store.set_profile_displayname(self.frank.to_string(), "Frank")
)
self.assertEqual(
(
self.get_success(
self.store.get_profile_displayname(self.frank.localpart)
self.store.get_profile_displayname(self.frank.to_string())
)
),
"Frank",
@@ -166,8 +166,10 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
def test_incoming_fed_query(self) -> None:
self.get_success(self.store.create_profile("caroline"))
self.get_success(self.store.set_profile_displayname("caroline", "Caroline"))
self.get_success(self.store.create_profile("@caroline:test"))
self.get_success(
self.store.set_profile_displayname("@caroline:test", "Caroline")
)
response = self.get_success(
self.query_handlers["profile"](
@@ -184,7 +186,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
def test_get_my_avatar(self) -> None:
self.get_success(
self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png"
self.frank.to_string(), "http://my.server/me.png"
)
)
avatar_url = self.get_success(self.handler.get_avatar_url(self.frank))
@@ -201,7 +203,11 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
(
self.get_success(
self.store.get_profile_avatar_url(self.frank.to_string())
)
),
"http://my.server/pic.gif",
)
@@ -215,7 +221,11 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
(
self.get_success(
self.store.get_profile_avatar_url(self.frank.to_string())
)
),
"http://my.server/me.png",
)
@@ -229,7 +239,11 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertIsNone(
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
(
self.get_success(
self.store.get_profile_avatar_url(self.frank.to_string())
)
),
)
def test_set_my_avatar_if_disabled(self) -> None:
@@ -238,12 +252,16 @@ class ProfileTestCase(unittest.HomeserverTestCase):
# Setting displayname for the first time is allowed
self.get_success(
self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png"
self.frank.to_string(), "http://my.server/me.png"
)
)
self.assertEqual(
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
(
self.get_success(
self.store.get_profile_avatar_url(self.frank.to_string())
)
),
"http://my.server/me.png",
)

View File

@@ -103,7 +103,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
self.assertEqual(email["added_at"], 0)
# Check that the displayname was assigned
displayname = self.get_success(self.store.get_profile_displayname("bob"))
displayname = self.get_success(self.store.get_profile_displayname("@bob:test"))
self.assertEqual(displayname, "Bobberino")
def test_can_register_admin_user(self) -> None:

View File

@@ -802,9 +802,9 @@ class UsersListTestCase(unittest.HomeserverTestCase):
# Set avatar URL to all users, that no user has a NULL value to avoid
# different sort order between SQlite and PostreSQL
self.get_success(self.store.set_profile_avatar_url("user1", "mxc://url3"))
self.get_success(self.store.set_profile_avatar_url("user2", "mxc://url2"))
self.get_success(self.store.set_profile_avatar_url("admin", "mxc://url1"))
self.get_success(self.store.set_profile_avatar_url("@user1:test", "mxc://url3"))
self.get_success(self.store.set_profile_avatar_url("@user2:test", "mxc://url2"))
self.get_success(self.store.set_profile_avatar_url("@admin:test", "mxc://url1"))
# order by default (name)
self._order_test([self.admin_user, user1, user2], None)
@@ -1127,7 +1127,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
# set attributes for user
self.get_success(
self.store.set_profile_avatar_url("user", "mxc://servername/mediaid")
self.store.set_profile_avatar_url("@user:test", "mxc://servername/mediaid")
)
self.get_success(
self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
@@ -1257,7 +1257,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
Reproduces #12257.
"""
# Patch `self.other_user` to have an empty string as their avatar.
self.get_success(self.store.set_profile_avatar_url("user", ""))
self.get_success(self.store.set_profile_avatar_url("@user:test", ""))
# Check we can still erase them.
channel = self.make_request(
@@ -2311,7 +2311,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# set attributes for user
self.get_success(
self.store.set_profile_avatar_url("user", "mxc://servername/mediaid")
self.store.set_profile_avatar_url("@user:test", "mxc://servername/mediaid")
)
self.get_success(
self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)

View File

@@ -45,7 +45,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"filter_id": "0"})
filter = self.get_success(
self.store.get_user_filter(user_localpart="apple", filter_id=0)
self.store.get_user_filter(user_id="@apple:test", filter_id=0)
)
self.pump()
self.assertEqual(filter, self.EXAMPLE_FILTER)
@@ -76,7 +76,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
def test_get_filter(self) -> None:
filter_id = self.get_success(
self.filtering.add_user_filter(
user_localpart="apple", user_filter=self.EXAMPLE_FILTER
user_id="@apple:test", user_filter=self.EXAMPLE_FILTER
)
)
self.reactor.advance(1)

View File

@@ -29,9 +29,9 @@ class DataStoreTestCase(unittest.HomeserverTestCase):
def test_get_users_paginate(self) -> None:
self.get_success(self.store.register_user(self.user.to_string(), "pass"))
self.get_success(self.store.create_profile(self.user.localpart))
self.get_success(self.store.create_profile(self.user.to_string()))
self.get_success(
self.store.set_profile_displayname(self.user.localpart, self.displayname)
self.store.set_profile_displayname(self.user.to_string(), self.displayname)
)
users, total = self.get_success(

View File

@@ -27,36 +27,38 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
self.u_frank = UserID.from_string("@frank:test")
def test_displayname(self) -> None:
self.get_success(self.store.create_profile(self.u_frank.localpart))
self.get_success(self.store.create_profile(self.u_frank.to_string()))
self.get_success(
self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
self.store.set_profile_displayname(self.u_frank.to_string(), "Frank")
)
self.assertEqual(
"Frank",
(
self.get_success(
self.store.get_profile_displayname(self.u_frank.localpart)
self.store.get_profile_displayname(self.u_frank.to_string())
)
),
)
# test set to None
self.get_success(
self.store.set_profile_displayname(self.u_frank.localpart, None)
self.store.set_profile_displayname(self.u_frank.to_string(), None)
)
self.assertIsNone(
self.get_success(self.store.get_profile_displayname(self.u_frank.localpart))
self.get_success(
self.store.get_profile_displayname(self.u_frank.to_string())
)
)
def test_avatar_url(self) -> None:
self.get_success(self.store.create_profile(self.u_frank.localpart))
self.get_success(self.store.create_profile(self.u_frank.to_string()))
self.get_success(
self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here"
self.u_frank.to_string(), "http://my.site/here"
)
)
@@ -64,16 +66,18 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
"http://my.site/here",
(
self.get_success(
self.store.get_profile_avatar_url(self.u_frank.localpart)
self.store.get_profile_avatar_url(self.u_frank.to_string())
)
),
)
# test set to None
self.get_success(
self.store.set_profile_avatar_url(self.u_frank.localpart, None)
self.store.set_profile_avatar_url(self.u_frank.to_string(), None)
)
self.assertIsNone(
self.get_success(self.store.get_profile_avatar_url(self.u_frank.localpart))
self.get_success(
self.store.get_profile_avatar_url(self.u_frank.to_string())
)
)