mirror of
https://github.com/element-hq/synapse.git
synced 2025-12-07 01:20:16 +00:00
Compare commits
14 Commits
erikj/dock
...
squah/expa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
07a5623059 | ||
|
|
f98141ceb2 | ||
|
|
06f9ababc4 | ||
|
|
1a4f41b3de | ||
|
|
1dcbff40d6 | ||
|
|
76d6379727 | ||
|
|
96bb319d14 | ||
|
|
e6c582095f | ||
|
|
b375e2abd9 | ||
|
|
0a734d0cf2 | ||
|
|
cc90467096 | ||
|
|
8182e8ad14 | ||
|
|
03ee93ee1a | ||
|
|
8810abab33 |
@@ -165,16 +165,14 @@ class Filtering:
|
||||
self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {})
|
||||
|
||||
async def get_user_filter(
|
||||
self, user_localpart: str, filter_id: Union[int, str]
|
||||
self, user_id: str, filter_id: Union[int, str]
|
||||
) -> "FilterCollection":
|
||||
result = await self.store.get_user_filter(user_localpart, filter_id)
|
||||
result = await self.store.get_user_filter(user_id, filter_id)
|
||||
return FilterCollection(self._hs, result)
|
||||
|
||||
def add_user_filter(
|
||||
self, user_localpart: str, user_filter: JsonDict
|
||||
) -> Awaitable[int]:
|
||||
def add_user_filter(self, user_id: str, user_filter: JsonDict) -> Awaitable[int]:
|
||||
self.check_valid_filter(user_filter)
|
||||
return self.store.add_user_filter(user_localpart, user_filter)
|
||||
return self.store.add_user_filter(user_id, user_filter)
|
||||
|
||||
# TODO(paul): surely we should probably add a delete_user_filter or
|
||||
# replace_user_filter at some point? There's no REST API specified for
|
||||
|
||||
@@ -19,7 +19,6 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
from synapse.api.errors import AuthError, StoreError, SynapseError
|
||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||
from synapse.types import UserID
|
||||
from synapse.util import stringutils
|
||||
from synapse.util.async_helpers import delay_cancellation
|
||||
|
||||
@@ -163,9 +162,7 @@ class AccountValidityHandler:
|
||||
return
|
||||
|
||||
try:
|
||||
user_display_name = await self.store.get_profile_displayname(
|
||||
UserID.from_string(user_id).localpart
|
||||
)
|
||||
user_display_name = await self.store.get_profile_displayname(user_id)
|
||||
if user_display_name is None:
|
||||
user_display_name = user_id
|
||||
except StoreError:
|
||||
|
||||
@@ -89,7 +89,7 @@ class AdminHandler:
|
||||
}
|
||||
|
||||
# Add additional user metadata
|
||||
profile = await self._store.get_profileinfo(user.localpart)
|
||||
profile = await self._store.get_profileinfo(user.to_string())
|
||||
threepids = await self._store.user_get_threepids(user.to_string())
|
||||
external_ids = [
|
||||
({"auth_provider": auth_provider, "external_id": external_id})
|
||||
|
||||
@@ -1756,9 +1756,7 @@ class AuthHandler:
|
||||
respond_with_html(request, 403, self._sso_account_deactivated_template)
|
||||
return
|
||||
|
||||
user_profile_data = await self.store.get_profileinfo(
|
||||
UserID.from_string(registered_user_id).localpart
|
||||
)
|
||||
user_profile_data = await self.store.get_profileinfo(registered_user_id)
|
||||
|
||||
# Store any extra attributes which will be passed in the login response.
|
||||
# Note that this is per-user so it may overwrite a previous value, this
|
||||
|
||||
@@ -282,8 +282,6 @@ class DeactivateAccountHandler:
|
||||
Args:
|
||||
user_id: ID of user to be re-activated
|
||||
"""
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
# Ensure the user is not marked as erased.
|
||||
await self.store.mark_user_not_erased(user_id)
|
||||
|
||||
@@ -297,5 +295,5 @@ class DeactivateAccountHandler:
|
||||
# Add the user to the directory, if necessary. Note that
|
||||
# this must be done after the user is re-activated, because
|
||||
# deactivated users are excluded from the user directory.
|
||||
profile = await self.store.get_profileinfo(user.localpart)
|
||||
profile = await self.store.get_profileinfo(user_id)
|
||||
await self.user_directory_handler.handle_local_profile_change(user_id, profile)
|
||||
|
||||
@@ -67,7 +67,7 @@ class ProfileHandler:
|
||||
target_user = UserID.from_string(user_id)
|
||||
|
||||
if self.hs.is_mine(target_user):
|
||||
profileinfo = await self.store.get_profileinfo(target_user.localpart)
|
||||
profileinfo = await self.store.get_profileinfo(user_id)
|
||||
if profileinfo.display_name is None:
|
||||
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
|
||||
|
||||
@@ -100,7 +100,7 @@ class ProfileHandler:
|
||||
if self.hs.is_mine(target_user):
|
||||
try:
|
||||
displayname = await self.store.get_profile_displayname(
|
||||
target_user.localpart
|
||||
target_user.to_string()
|
||||
)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
@@ -147,7 +147,7 @@ class ProfileHandler:
|
||||
raise AuthError(400, "Cannot set another user's displayname")
|
||||
|
||||
if not by_admin and not self.hs.config.registration.enable_set_displayname:
|
||||
profile = await self.store.get_profileinfo(target_user.localpart)
|
||||
profile = await self.store.get_profileinfo(target_user.to_string())
|
||||
if profile.display_name:
|
||||
raise SynapseError(
|
||||
400,
|
||||
@@ -179,10 +179,10 @@ class ProfileHandler:
|
||||
)
|
||||
|
||||
await self.store.set_profile_displayname(
|
||||
target_user.localpart, displayname_to_set
|
||||
target_user.to_string(), displayname_to_set
|
||||
)
|
||||
|
||||
profile = await self.store.get_profileinfo(target_user.localpart)
|
||||
profile = await self.store.get_profileinfo(target_user.to_string())
|
||||
await self.user_directory_handler.handle_local_profile_change(
|
||||
target_user.to_string(), profile
|
||||
)
|
||||
@@ -197,7 +197,7 @@ class ProfileHandler:
|
||||
if self.hs.is_mine(target_user):
|
||||
try:
|
||||
avatar_url = await self.store.get_profile_avatar_url(
|
||||
target_user.localpart
|
||||
target_user.to_string()
|
||||
)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
@@ -243,7 +243,7 @@ class ProfileHandler:
|
||||
raise AuthError(400, "Cannot set another user's avatar_url")
|
||||
|
||||
if not by_admin and not self.hs.config.registration.enable_set_avatar_url:
|
||||
profile = await self.store.get_profileinfo(target_user.localpart)
|
||||
profile = await self.store.get_profileinfo(target_user.to_string())
|
||||
if profile.avatar_url:
|
||||
raise SynapseError(
|
||||
400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
|
||||
@@ -273,10 +273,10 @@ class ProfileHandler:
|
||||
)
|
||||
|
||||
await self.store.set_profile_avatar_url(
|
||||
target_user.localpart, avatar_url_to_set
|
||||
target_user.to_string(), avatar_url_to_set
|
||||
)
|
||||
|
||||
profile = await self.store.get_profileinfo(target_user.localpart)
|
||||
profile = await self.store.get_profileinfo(target_user.to_string())
|
||||
await self.user_directory_handler.handle_local_profile_change(
|
||||
target_user.to_string(), profile
|
||||
)
|
||||
@@ -364,7 +364,8 @@ class ProfileHandler:
|
||||
Codes.FORBIDDEN,
|
||||
)
|
||||
|
||||
user = UserID.from_string(args["user_id"])
|
||||
user_id = args["user_id"]
|
||||
user = UserID.from_string(user_id)
|
||||
if not self.hs.is_mine(user):
|
||||
raise SynapseError(400, "User is not hosted on this homeserver")
|
||||
|
||||
@@ -374,12 +375,12 @@ class ProfileHandler:
|
||||
try:
|
||||
if just_field is None or just_field == "displayname":
|
||||
response["displayname"] = await self.store.get_profile_displayname(
|
||||
user.localpart
|
||||
user_id
|
||||
)
|
||||
|
||||
if just_field is None or just_field == "avatar_url":
|
||||
response["avatar_url"] = await self.store.get_profile_avatar_url(
|
||||
user.localpart
|
||||
user_id
|
||||
)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
|
||||
@@ -314,7 +314,7 @@ class RegistrationHandler:
|
||||
approved=approved,
|
||||
)
|
||||
|
||||
profile = await self.store.get_profileinfo(localpart)
|
||||
profile = await self.store.get_profileinfo(user_id)
|
||||
await self.user_directory_handler.handle_local_profile_change(
|
||||
user_id, profile
|
||||
)
|
||||
|
||||
@@ -648,7 +648,8 @@ class ModuleApi:
|
||||
Returns:
|
||||
The profile information (i.e. display name and avatar URL).
|
||||
"""
|
||||
return await self._store.get_profileinfo(localpart)
|
||||
user = UserID(localpart, self._hs.hostname)
|
||||
return await self._store.get_profileinfo(user.to_string())
|
||||
|
||||
async def get_threepids_for_user(self, user_id: str) -> List[Dict[str, str]]:
|
||||
"""Look up the threepids (email addresses and phone numbers) associated with the
|
||||
|
||||
@@ -37,7 +37,7 @@ from synapse.push.push_types import (
|
||||
TemplateVars,
|
||||
)
|
||||
from synapse.storage.databases.main.event_push_actions import EmailPushAction
|
||||
from synapse.types import StateMap, UserID
|
||||
from synapse.types import StateMap
|
||||
from synapse.types.state import StateFilter
|
||||
from synapse.util.async_helpers import concurrently_execute
|
||||
from synapse.visibility import filter_events_for_client
|
||||
@@ -246,9 +246,7 @@ class Mailer:
|
||||
state_by_room = {}
|
||||
|
||||
try:
|
||||
user_display_name = await self.store.get_profile_displayname(
|
||||
UserID.from_string(user_id).localpart
|
||||
)
|
||||
user_display_name = await self.store.get_profile_displayname(user_id)
|
||||
if user_display_name is None:
|
||||
user_display_name = user_id
|
||||
except StoreError:
|
||||
|
||||
@@ -58,7 +58,7 @@ class GetFilterRestServlet(RestServlet):
|
||||
|
||||
try:
|
||||
filter_collection = await self.filtering.get_user_filter(
|
||||
user_localpart=target_user.localpart, filter_id=filter_id_int
|
||||
user_id=user_id, filter_id=filter_id_int
|
||||
)
|
||||
except StoreError as e:
|
||||
if e.code != 404:
|
||||
@@ -94,7 +94,7 @@ class CreateFilterRestServlet(RestServlet):
|
||||
set_timeline_upper_limit(content, self.hs.config.server.filter_timeline_limit)
|
||||
|
||||
filter_id = await self.filtering.add_user_filter(
|
||||
user_localpart=target_user.localpart, user_filter=content
|
||||
user_id=user_id, user_filter=content
|
||||
)
|
||||
|
||||
return 200, {"filter_id": str(filter_id)}
|
||||
|
||||
@@ -178,7 +178,7 @@ class SyncRestServlet(RestServlet):
|
||||
else:
|
||||
try:
|
||||
filter_collection = await self.filtering.get_user_filter(
|
||||
user.localpart, filter_id
|
||||
user.to_string(), filter_id
|
||||
)
|
||||
except StoreError as err:
|
||||
if err.code != 404:
|
||||
|
||||
@@ -43,7 +43,7 @@ from .event_federation import EventFederationStore
|
||||
from .event_push_actions import EventPushActionsStore
|
||||
from .events_bg_updates import EventsBackgroundUpdatesStore
|
||||
from .events_forward_extremities import EventForwardExtremitiesStore
|
||||
from .filtering import FilteringWorkerStore
|
||||
from .filtering import FilteringStore
|
||||
from .keys import KeyStore
|
||||
from .lock import LockStore
|
||||
from .media_repository import MediaRepositoryStore
|
||||
@@ -99,7 +99,7 @@ class DataStore(
|
||||
EventFederationStore,
|
||||
MediaRepositoryStore,
|
||||
RejectionsStore,
|
||||
FilteringWorkerStore,
|
||||
FilteringStore,
|
||||
PusherStore,
|
||||
PushRuleStore,
|
||||
ApplicationServiceTransactionStore,
|
||||
|
||||
@@ -13,22 +13,30 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union, cast
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from synapse.api.errors import Codes, StoreError, SynapseError
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
from synapse.types import JsonDict
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
class FilteringWorkerStore(SQLBaseStore):
|
||||
@cached(num_args=2)
|
||||
async def get_user_filter(
|
||||
self, user_localpart: str, filter_id: Union[int, str]
|
||||
self, user_id: str, filter_id: Union[int, str]
|
||||
) -> JsonDict:
|
||||
user_localpart = UserID.from_string(user_id).localpart
|
||||
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
|
||||
# with a coherent error message rather than 500 M_UNKNOWN.
|
||||
try:
|
||||
@@ -36,17 +44,31 @@ class FilteringWorkerStore(SQLBaseStore):
|
||||
except ValueError:
|
||||
raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
|
||||
|
||||
def_json = await self.db_pool.simple_select_one_onecol(
|
||||
table="user_filters",
|
||||
keyvalues={"user_id": user_localpart, "filter_id": filter_id},
|
||||
retcol="filter_json",
|
||||
allow_none=False,
|
||||
desc="get_user_filter",
|
||||
)
|
||||
try:
|
||||
def_json = await self.db_pool.simple_select_one_onecol(
|
||||
table="user_filters",
|
||||
keyvalues={"full_user_id": user_id, "filter_id": filter_id},
|
||||
retcol="filter_json",
|
||||
allow_none=False,
|
||||
desc="get_user_filter",
|
||||
)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
# Fall back to the `user_id` column.
|
||||
def_json = await self.db_pool.simple_select_one_onecol(
|
||||
table="user_filters",
|
||||
keyvalues={"user_id": user_localpart, "filter_id": filter_id},
|
||||
retcol="filter_json",
|
||||
allow_none=False,
|
||||
desc="get_user_filter",
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
return db_to_json(def_json)
|
||||
|
||||
async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int:
|
||||
async def add_user_filter(self, user_id: str, user_filter: JsonDict) -> int:
|
||||
user_localpart = UserID.from_string(user_id).localpart
|
||||
def_json = encode_canonical_json(user_filter)
|
||||
|
||||
# Need an atomic transaction to SELECT the maximal ID so far then
|
||||
@@ -70,10 +92,10 @@ class FilteringWorkerStore(SQLBaseStore):
|
||||
filter_id = max_id + 1
|
||||
|
||||
sql = (
|
||||
"INSERT INTO user_filters (user_id, filter_id, filter_json)"
|
||||
"VALUES(?, ?, ?)"
|
||||
"INSERT INTO user_filters (full_user_id, user_id, filter_id, filter_json)"
|
||||
"VALUES(?, ?, ?, ?)"
|
||||
)
|
||||
txn.execute(sql, (user_localpart, filter_id, bytearray(def_json)))
|
||||
txn.execute(sql, (user_id, user_localpart, filter_id, bytearray(def_json)))
|
||||
|
||||
return filter_id
|
||||
|
||||
@@ -97,3 +119,67 @@ class FilteringWorkerStore(SQLBaseStore):
|
||||
|
||||
if attempts >= 5:
|
||||
raise StoreError(500, "Couldn't generate a filter ID.")
|
||||
|
||||
|
||||
class FilteringBackgroundUpdateStore(FilteringWorkerStore):
|
||||
POPULATE_USER_FILTERS_FULL_USER_ID = "populate_user_filters_full_user_id"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
hs: "HomeServer",
|
||||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
self.POPULATE_USER_FILTERS_FULL_USER_ID,
|
||||
self._populate_user_filters_full_user_id,
|
||||
)
|
||||
|
||||
async def _populate_user_filters_full_user_id(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
"""Populates the `user_filters.full_user_id` column.
|
||||
|
||||
In a future Synapse version, this column will be renamed to `user_id`, replacing
|
||||
the existing `user_id` column.
|
||||
|
||||
Note that completion of this background update does not imply that there are no
|
||||
longer any `NULL` values in `full_user_id`. Until the old `user_id` column has
|
||||
been removed, Synapse may be rolled back to a previous version which does not
|
||||
populate `full_user_id` after the background update has finished.
|
||||
"""
|
||||
|
||||
def _populate_user_filters_full_user_id_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> bool:
|
||||
sql = """
|
||||
UPDATE user_filters
|
||||
SET full_user_id = '@' || user_id || ':' || ?
|
||||
WHERE user_id IN (
|
||||
SELECT user_id
|
||||
FROM user_filters
|
||||
WHERE full_user_id IS NULL
|
||||
LIMIT ?
|
||||
)
|
||||
"""
|
||||
txn.execute(sql, (self.hs.hostname, batch_size))
|
||||
|
||||
return txn.rowcount == 0
|
||||
|
||||
finished = await self.db_pool.runInteraction(
|
||||
"_populate_user_filters_full_user_id_txn",
|
||||
_populate_user_filters_full_user_id_txn,
|
||||
)
|
||||
|
||||
if finished:
|
||||
await self.db_pool.updates._end_background_update(
|
||||
self.POPULATE_USER_FILTERS_FULL_USER_ID
|
||||
)
|
||||
|
||||
return batch_size
|
||||
|
||||
|
||||
class FilteringStore(FilteringBackgroundUpdateStore):
|
||||
pass
|
||||
|
||||
@@ -11,22 +11,41 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.databases.main.roommember import ProfileInfo
|
||||
from synapse.types import JsonDict, UserID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
class ProfileWorkerStore(SQLBaseStore):
|
||||
async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
|
||||
async def get_profileinfo(self, user_id: str) -> ProfileInfo:
|
||||
try:
|
||||
profile = await self.db_pool.simple_select_one(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
keyvalues={"full_user_id": user_id},
|
||||
retcols=("displayname", "avatar_url"),
|
||||
allow_none=True,
|
||||
desc="get_profileinfo",
|
||||
)
|
||||
if profile is None:
|
||||
# Fall back to the `user_id` column.
|
||||
user_localpart = UserID.from_string(user_id).localpart
|
||||
profile = await self.db_pool.simple_select_one(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
retcols=("displayname", "avatar_url"),
|
||||
desc="get_profileinfo",
|
||||
)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
# no match
|
||||
@@ -38,47 +57,138 @@ class ProfileWorkerStore(SQLBaseStore):
|
||||
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
|
||||
)
|
||||
|
||||
async def get_profile_displayname(self, user_localpart: str) -> Optional[str]:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
retcol="displayname",
|
||||
desc="get_profile_displayname",
|
||||
)
|
||||
async def get_profile_displayname(self, user_id: str) -> Optional[str]:
|
||||
try:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="profiles",
|
||||
keyvalues={"full_user_id": user_id},
|
||||
retcol="displayname",
|
||||
desc="get_profile_displayname",
|
||||
)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
# Fall back to the `user_id` column.
|
||||
user_localpart = UserID.from_string(user_id).localpart
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
retcol="displayname",
|
||||
desc="get_profile_displayname",
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
retcol="avatar_url",
|
||||
desc="get_profile_avatar_url",
|
||||
)
|
||||
async def get_profile_avatar_url(self, user_id: str) -> Optional[str]:
|
||||
try:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="profiles",
|
||||
keyvalues={"full_user_id": user_id},
|
||||
retcol="avatar_url",
|
||||
desc="get_profile_avatar_url",
|
||||
)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
# Fall back to the `user_id` column.
|
||||
user_localpart = UserID.from_string(user_id).localpart
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
retcol="avatar_url",
|
||||
desc="get_profile_avatar_url",
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
async def create_profile(self, user_localpart: str) -> None:
|
||||
async def create_profile(self, user_id: str) -> None:
|
||||
user_localpart = UserID.from_string(user_id).localpart
|
||||
await self.db_pool.simple_insert(
|
||||
table="profiles", values={"user_id": user_localpart}, desc="create_profile"
|
||||
table="profiles",
|
||||
values={"user_id": user_localpart, "full_user_id": user_id},
|
||||
desc="create_profile",
|
||||
)
|
||||
|
||||
async def set_profile_displayname(
|
||||
self, user_localpart: str, new_displayname: Optional[str]
|
||||
self, user_id: str, new_displayname: Optional[str]
|
||||
) -> None:
|
||||
user_localpart = UserID.from_string(user_id).localpart
|
||||
await self.db_pool.simple_upsert(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
values={"displayname": new_displayname},
|
||||
values={"full_user_id": user_id, "displayname": new_displayname},
|
||||
desc="set_profile_displayname",
|
||||
)
|
||||
|
||||
async def set_profile_avatar_url(
|
||||
self, user_localpart: str, new_avatar_url: Optional[str]
|
||||
self, user_id: str, new_avatar_url: Optional[str]
|
||||
) -> None:
|
||||
user_localpart = UserID.from_string(user_id).localpart
|
||||
await self.db_pool.simple_upsert(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
values={"avatar_url": new_avatar_url},
|
||||
values={"full_user_id": user_id, "avatar_url": new_avatar_url},
|
||||
desc="set_profile_avatar_url",
|
||||
)
|
||||
|
||||
|
||||
class ProfileStore(ProfileWorkerStore):
|
||||
class ProfileBackgroundUpdateStore(ProfileWorkerStore):
|
||||
POPULATE_PROFILES_FULL_USER_ID = "populate_profiles_full_user_id"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
hs: "HomeServer",
|
||||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
self.POPULATE_PROFILES_FULL_USER_ID,
|
||||
self._populate_profiles_full_user_id,
|
||||
)
|
||||
|
||||
async def _populate_profiles_full_user_id(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
"""Populates the `profiles.full_user_id` column.
|
||||
|
||||
In a future Synapse version, this column will be renamed to `user_id`, replacing
|
||||
the existing `user_id` column.
|
||||
|
||||
Note that completion of this background update does not imply that there are no
|
||||
longer any `NULL` values in `full_user_id`. Until the old `user_id` column has
|
||||
been removed, Synapse may be rolled back to a previous version which does not
|
||||
populate `full_user_id` after the background update has finished.
|
||||
"""
|
||||
|
||||
def _populate_profiles_full_user_id_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> bool:
|
||||
sql = """
|
||||
UPDATE profiles
|
||||
SET full_user_id = '@' || user_id || ':' || ?
|
||||
WHERE user_id IN (
|
||||
SELECT user_id
|
||||
FROM profiles
|
||||
WHERE full_user_id IS NULL
|
||||
LIMIT ?
|
||||
)
|
||||
"""
|
||||
txn.execute(sql, (self.hs.hostname, batch_size))
|
||||
|
||||
return txn.rowcount == 0
|
||||
|
||||
finished = await self.db_pool.runInteraction(
|
||||
"_populate_profiles_full_user_id_txn",
|
||||
_populate_profiles_full_user_id_txn,
|
||||
)
|
||||
|
||||
if finished:
|
||||
await self.db_pool.updates._end_background_update(
|
||||
self.POPULATE_PROFILES_FULL_USER_ID
|
||||
)
|
||||
|
||||
return batch_size
|
||||
|
||||
|
||||
class ProfileStore(ProfileBackgroundUpdateStore):
|
||||
pass
|
||||
|
||||
@@ -2414,8 +2414,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
||||
# *obviously* the 'profiles' table uses localpart for user_id
|
||||
# while everything else uses the full mxid.
|
||||
txn.execute(
|
||||
"INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
|
||||
(user_id_obj.localpart, create_profile_with_displayname),
|
||||
"INSERT INTO profiles(full_user_id, user_id, displayname) VALUES (?,?,?)",
|
||||
(user_id, user_id_obj.localpart, create_profile_with_displayname),
|
||||
)
|
||||
|
||||
if self.hs.config.stats.stats_enabled:
|
||||
|
||||
@@ -383,7 +383,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
|
||||
for user_id in users_to_work_on:
|
||||
if await self.should_include_local_user_in_dir(user_id):
|
||||
profile = await self.get_profileinfo(get_localpart_from_id(user_id)) # type: ignore[attr-defined]
|
||||
profile = await self.get_profileinfo(user_id) # type: ignore[attr-defined]
|
||||
await self.update_profile_in_user_dir(
|
||||
user_id, profile.display_name, profile.avatar_url
|
||||
)
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
SCHEMA_VERSION = 74 # remember to update the list below when updating
|
||||
SCHEMA_VERSION = 75 # remember to update the list below when updating
|
||||
"""Represents the expectations made by the codebase about the database schema
|
||||
|
||||
This should be incremented whenever the codebase changes its requirements on the
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
/* Copyright 2023 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
ALTER TABLE profiles ADD COLUMN full_user_id TEXT;
|
||||
|
||||
-- Add a new constraint on the new column, mirroring the `profiles_user_id_key`
|
||||
-- constraint.
|
||||
ALTER TABLE ONLY profiles
|
||||
ADD CONSTRAINT profiles_full_user_id_key UNIQUE (full_user_id);
|
||||
|
||||
-- Also ensure that new rows have the `full_user_id` field populated.
|
||||
-- TODO: Move this to phase two of the migration. In a multi-worker deployment, it will
|
||||
-- prevent un-updated workers from doing any UPDATEs. That is, it effectively
|
||||
-- prevents rollback of Synapse to an earlier version when the column has not been
|
||||
-- fully populated.
|
||||
ALTER TABLE ONLY profiles
|
||||
ADD CONSTRAINT profiles_full_user_id_not_null CHECK (full_user_id IS NOT NULL) NOT VALID;
|
||||
-- `profiles` can contain on the order of 10s/100s of millions of rows. We use
|
||||
-- `NOT VALID` so that we do not lock the table to check existing rows. New rows will
|
||||
-- still be checked.
|
||||
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||
(7501, 'populate_profiles_full_user_id', '{}');
|
||||
@@ -0,0 +1,43 @@
|
||||
/* Copyright 2023 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
ALTER TABLE profiles ADD COLUMN full_user_id TEXT;
|
||||
|
||||
-- Add a new constraint on the new column, mirroring the `user_id` constraint.
|
||||
--
|
||||
-- SQLite doesn't support modifying constraints on an existing table, so it must be
|
||||
-- recreated.
|
||||
CREATE TABLE profiles_new(
|
||||
full_user_id TEXT,
|
||||
user_id TEXT NOT NULL,
|
||||
displayname TEXT,
|
||||
avatar_url TEXT,
|
||||
UNIQUE (full_user_id),
|
||||
UNIQUE (user_id)
|
||||
);
|
||||
|
||||
-- Copy the data.
|
||||
INSERT INTO profiles_new (full_user_id, user_id, displayname, avatar_url)
|
||||
SELECT NULL, user_id, displayname, avatar_url
|
||||
FROM profiles;
|
||||
|
||||
-- Drop the old table.
|
||||
DROP TABLE profiles;
|
||||
|
||||
-- Rename the table.
|
||||
ALTER TABLE profiles_new RENAME TO profiles;
|
||||
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||
(7501, 'populate_profiles_full_user_id', '{}');
|
||||
@@ -0,0 +1,26 @@
|
||||
/* Copyright 2023 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
ALTER TABLE user_filters ADD COLUMN full_user_id TEXT;
|
||||
|
||||
-- Add a unique index on the new column, mirroring the `user_filters_unique` unique
|
||||
-- index.
|
||||
CREATE UNIQUE INDEX full_user_filters_unique ON user_filters (full_user_id, filter_id);
|
||||
-- NB: This will lock the table for writes while the index is being built.
|
||||
-- There are around 4,000,000 user_filters on matrix.org so we expect this to take
|
||||
-- a couple of seconds at most.
|
||||
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||
(7502, 'populate_user_filters_full_user_id', '{}');
|
||||
@@ -33,7 +33,8 @@ from synapse.util.frozenutils import freeze
|
||||
from tests import unittest
|
||||
from tests.events.test_utils import MockEvent
|
||||
|
||||
user_localpart = "test_user"
|
||||
user_id = "@test_user:test"
|
||||
user2_id = "@test_user2:test"
|
||||
|
||||
|
||||
class FilteringTestCase(unittest.HomeserverTestCase):
|
||||
@@ -437,7 +438,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||
user_filter_json = {"presence": {"senders": ["@foo:bar"]}}
|
||||
filter_id = self.get_success(
|
||||
self.datastore.add_user_filter(
|
||||
user_localpart=user_localpart, user_filter=user_filter_json
|
||||
user_id=user_id, user_filter=user_filter_json
|
||||
)
|
||||
)
|
||||
presence_states = [
|
||||
@@ -453,9 +454,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||
]
|
||||
|
||||
user_filter = self.get_success(
|
||||
self.filtering.get_user_filter(
|
||||
user_localpart=user_localpart, filter_id=filter_id
|
||||
)
|
||||
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
|
||||
)
|
||||
|
||||
results = self.get_success(user_filter.filter_presence(presence_states))
|
||||
@@ -467,7 +466,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
filter_id = self.get_success(
|
||||
self.datastore.add_user_filter(
|
||||
user_localpart=user_localpart + "2", user_filter=user_filter_json
|
||||
user_id=user2_id, user_filter=user_filter_json
|
||||
)
|
||||
)
|
||||
presence_states = [
|
||||
@@ -483,9 +482,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||
]
|
||||
|
||||
user_filter = self.get_success(
|
||||
self.filtering.get_user_filter(
|
||||
user_localpart=user_localpart + "2", filter_id=filter_id
|
||||
)
|
||||
self.filtering.get_user_filter(user_id=user2_id, filter_id=filter_id)
|
||||
)
|
||||
|
||||
results = self.get_success(user_filter.filter_presence(presence_states))
|
||||
@@ -495,16 +492,14 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
|
||||
filter_id = self.get_success(
|
||||
self.datastore.add_user_filter(
|
||||
user_localpart=user_localpart, user_filter=user_filter_json
|
||||
user_id=user_id, user_filter=user_filter_json
|
||||
)
|
||||
)
|
||||
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
|
||||
events = [event]
|
||||
|
||||
user_filter = self.get_success(
|
||||
self.filtering.get_user_filter(
|
||||
user_localpart=user_localpart, filter_id=filter_id
|
||||
)
|
||||
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
|
||||
)
|
||||
|
||||
results = self.get_success(user_filter.filter_room_state(events=events))
|
||||
@@ -514,7 +509,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
|
||||
filter_id = self.get_success(
|
||||
self.datastore.add_user_filter(
|
||||
user_localpart=user_localpart, user_filter=user_filter_json
|
||||
user_id=user_id, user_filter=user_filter_json
|
||||
)
|
||||
)
|
||||
event = MockEvent(
|
||||
@@ -523,9 +518,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||
events = [event]
|
||||
|
||||
user_filter = self.get_success(
|
||||
self.filtering.get_user_filter(
|
||||
user_localpart=user_localpart, filter_id=filter_id
|
||||
)
|
||||
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
|
||||
)
|
||||
|
||||
results = self.get_success(user_filter.filter_room_state(events))
|
||||
@@ -598,7 +591,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
filter_id = self.get_success(
|
||||
self.filtering.add_user_filter(
|
||||
user_localpart=user_localpart, user_filter=user_filter_json
|
||||
user_id=user_id, user_filter=user_filter_json
|
||||
)
|
||||
)
|
||||
|
||||
@@ -607,9 +600,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||
user_filter_json,
|
||||
(
|
||||
self.get_success(
|
||||
self.datastore.get_user_filter(
|
||||
user_localpart=user_localpart, filter_id=0
|
||||
)
|
||||
self.datastore.get_user_filter(user_id=user_id, filter_id=0)
|
||||
)
|
||||
),
|
||||
)
|
||||
@@ -619,14 +610,12 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
filter_id = self.get_success(
|
||||
self.datastore.add_user_filter(
|
||||
user_localpart=user_localpart, user_filter=user_filter_json
|
||||
user_id=user_id, user_filter=user_filter_json
|
||||
)
|
||||
)
|
||||
|
||||
filter = self.get_success(
|
||||
self.filtering.get_user_filter(
|
||||
user_localpart=user_localpart, filter_id=filter_id
|
||||
)
|
||||
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
|
||||
)
|
||||
|
||||
self.assertEqual(filter.get_filter_json(), user_filter_json)
|
||||
|
||||
@@ -67,7 +67,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
def test_get_my_name(self) -> None:
|
||||
self.get_success(
|
||||
self.store.set_profile_displayname(self.frank.localpart, "Frank")
|
||||
self.store.set_profile_displayname(self.frank.to_string(), "Frank")
|
||||
)
|
||||
|
||||
displayname = self.get_success(self.handler.get_displayname(self.frank))
|
||||
@@ -84,7 +84,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(
|
||||
(
|
||||
self.get_success(
|
||||
self.store.get_profile_displayname(self.frank.localpart)
|
||||
self.store.get_profile_displayname(self.frank.to_string())
|
||||
)
|
||||
),
|
||||
"Frank Jr.",
|
||||
@@ -100,7 +100,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(
|
||||
(
|
||||
self.get_success(
|
||||
self.store.get_profile_displayname(self.frank.localpart)
|
||||
self.store.get_profile_displayname(self.frank.to_string())
|
||||
)
|
||||
),
|
||||
"Frank",
|
||||
@@ -114,7 +114,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.assertIsNone(
|
||||
self.get_success(self.store.get_profile_displayname(self.frank.localpart))
|
||||
self.get_success(self.store.get_profile_displayname(self.frank.to_string()))
|
||||
)
|
||||
|
||||
def test_set_my_name_if_disabled(self) -> None:
|
||||
@@ -122,13 +122,13 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# Setting displayname for the first time is allowed
|
||||
self.get_success(
|
||||
self.store.set_profile_displayname(self.frank.localpart, "Frank")
|
||||
self.store.set_profile_displayname(self.frank.to_string(), "Frank")
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
(
|
||||
self.get_success(
|
||||
self.store.get_profile_displayname(self.frank.localpart)
|
||||
self.store.get_profile_displayname(self.frank.to_string())
|
||||
)
|
||||
),
|
||||
"Frank",
|
||||
@@ -166,8 +166,10 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
def test_incoming_fed_query(self) -> None:
|
||||
self.get_success(self.store.create_profile("caroline"))
|
||||
self.get_success(self.store.set_profile_displayname("caroline", "Caroline"))
|
||||
self.get_success(self.store.create_profile("@caroline:test"))
|
||||
self.get_success(
|
||||
self.store.set_profile_displayname("@caroline:test", "Caroline")
|
||||
)
|
||||
|
||||
response = self.get_success(
|
||||
self.query_handlers["profile"](
|
||||
@@ -184,7 +186,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
def test_get_my_avatar(self) -> None:
|
||||
self.get_success(
|
||||
self.store.set_profile_avatar_url(
|
||||
self.frank.localpart, "http://my.server/me.png"
|
||||
self.frank.to_string(), "http://my.server/me.png"
|
||||
)
|
||||
)
|
||||
avatar_url = self.get_success(self.handler.get_avatar_url(self.frank))
|
||||
@@ -201,7 +203,11 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
|
||||
(
|
||||
self.get_success(
|
||||
self.store.get_profile_avatar_url(self.frank.to_string())
|
||||
)
|
||||
),
|
||||
"http://my.server/pic.gif",
|
||||
)
|
||||
|
||||
@@ -215,7 +221,11 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
|
||||
(
|
||||
self.get_success(
|
||||
self.store.get_profile_avatar_url(self.frank.to_string())
|
||||
)
|
||||
),
|
||||
"http://my.server/me.png",
|
||||
)
|
||||
|
||||
@@ -229,7 +239,11 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.assertIsNone(
|
||||
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
|
||||
(
|
||||
self.get_success(
|
||||
self.store.get_profile_avatar_url(self.frank.to_string())
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
def test_set_my_avatar_if_disabled(self) -> None:
|
||||
@@ -238,12 +252,16 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
# Setting displayname for the first time is allowed
|
||||
self.get_success(
|
||||
self.store.set_profile_avatar_url(
|
||||
self.frank.localpart, "http://my.server/me.png"
|
||||
self.frank.to_string(), "http://my.server/me.png"
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
|
||||
(
|
||||
self.get_success(
|
||||
self.store.get_profile_avatar_url(self.frank.to_string())
|
||||
)
|
||||
),
|
||||
"http://my.server/me.png",
|
||||
)
|
||||
|
||||
|
||||
@@ -103,7 +103,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
|
||||
self.assertEqual(email["added_at"], 0)
|
||||
|
||||
# Check that the displayname was assigned
|
||||
displayname = self.get_success(self.store.get_profile_displayname("bob"))
|
||||
displayname = self.get_success(self.store.get_profile_displayname("@bob:test"))
|
||||
self.assertEqual(displayname, "Bobberino")
|
||||
|
||||
def test_can_register_admin_user(self) -> None:
|
||||
|
||||
@@ -802,9 +802,9 @@ class UsersListTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# Set avatar URL to all users, that no user has a NULL value to avoid
|
||||
# different sort order between SQlite and PostreSQL
|
||||
self.get_success(self.store.set_profile_avatar_url("user1", "mxc://url3"))
|
||||
self.get_success(self.store.set_profile_avatar_url("user2", "mxc://url2"))
|
||||
self.get_success(self.store.set_profile_avatar_url("admin", "mxc://url1"))
|
||||
self.get_success(self.store.set_profile_avatar_url("@user1:test", "mxc://url3"))
|
||||
self.get_success(self.store.set_profile_avatar_url("@user2:test", "mxc://url2"))
|
||||
self.get_success(self.store.set_profile_avatar_url("@admin:test", "mxc://url1"))
|
||||
|
||||
# order by default (name)
|
||||
self._order_test([self.admin_user, user1, user2], None)
|
||||
@@ -1127,7 +1127,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# set attributes for user
|
||||
self.get_success(
|
||||
self.store.set_profile_avatar_url("user", "mxc://servername/mediaid")
|
||||
self.store.set_profile_avatar_url("@user:test", "mxc://servername/mediaid")
|
||||
)
|
||||
self.get_success(
|
||||
self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
|
||||
@@ -1257,7 +1257,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
|
||||
Reproduces #12257.
|
||||
"""
|
||||
# Patch `self.other_user` to have an empty string as their avatar.
|
||||
self.get_success(self.store.set_profile_avatar_url("user", ""))
|
||||
self.get_success(self.store.set_profile_avatar_url("@user:test", ""))
|
||||
|
||||
# Check we can still erase them.
|
||||
channel = self.make_request(
|
||||
@@ -2311,7 +2311,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# set attributes for user
|
||||
self.get_success(
|
||||
self.store.set_profile_avatar_url("user", "mxc://servername/mediaid")
|
||||
self.store.set_profile_avatar_url("@user:test", "mxc://servername/mediaid")
|
||||
)
|
||||
self.get_success(
|
||||
self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
|
||||
|
||||
@@ -45,7 +45,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.json_body, {"filter_id": "0"})
|
||||
filter = self.get_success(
|
||||
self.store.get_user_filter(user_localpart="apple", filter_id=0)
|
||||
self.store.get_user_filter(user_id="@apple:test", filter_id=0)
|
||||
)
|
||||
self.pump()
|
||||
self.assertEqual(filter, self.EXAMPLE_FILTER)
|
||||
@@ -76,7 +76,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
||||
def test_get_filter(self) -> None:
|
||||
filter_id = self.get_success(
|
||||
self.filtering.add_user_filter(
|
||||
user_localpart="apple", user_filter=self.EXAMPLE_FILTER
|
||||
user_id="@apple:test", user_filter=self.EXAMPLE_FILTER
|
||||
)
|
||||
)
|
||||
self.reactor.advance(1)
|
||||
|
||||
@@ -29,9 +29,9 @@ class DataStoreTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
def test_get_users_paginate(self) -> None:
|
||||
self.get_success(self.store.register_user(self.user.to_string(), "pass"))
|
||||
self.get_success(self.store.create_profile(self.user.localpart))
|
||||
self.get_success(self.store.create_profile(self.user.to_string()))
|
||||
self.get_success(
|
||||
self.store.set_profile_displayname(self.user.localpart, self.displayname)
|
||||
self.store.set_profile_displayname(self.user.to_string(), self.displayname)
|
||||
)
|
||||
|
||||
users, total = self.get_success(
|
||||
|
||||
@@ -27,36 +27,38 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
|
||||
self.u_frank = UserID.from_string("@frank:test")
|
||||
|
||||
def test_displayname(self) -> None:
|
||||
self.get_success(self.store.create_profile(self.u_frank.localpart))
|
||||
self.get_success(self.store.create_profile(self.u_frank.to_string()))
|
||||
|
||||
self.get_success(
|
||||
self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
|
||||
self.store.set_profile_displayname(self.u_frank.to_string(), "Frank")
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
"Frank",
|
||||
(
|
||||
self.get_success(
|
||||
self.store.get_profile_displayname(self.u_frank.localpart)
|
||||
self.store.get_profile_displayname(self.u_frank.to_string())
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
# test set to None
|
||||
self.get_success(
|
||||
self.store.set_profile_displayname(self.u_frank.localpart, None)
|
||||
self.store.set_profile_displayname(self.u_frank.to_string(), None)
|
||||
)
|
||||
|
||||
self.assertIsNone(
|
||||
self.get_success(self.store.get_profile_displayname(self.u_frank.localpart))
|
||||
self.get_success(
|
||||
self.store.get_profile_displayname(self.u_frank.to_string())
|
||||
)
|
||||
)
|
||||
|
||||
def test_avatar_url(self) -> None:
|
||||
self.get_success(self.store.create_profile(self.u_frank.localpart))
|
||||
self.get_success(self.store.create_profile(self.u_frank.to_string()))
|
||||
|
||||
self.get_success(
|
||||
self.store.set_profile_avatar_url(
|
||||
self.u_frank.localpart, "http://my.site/here"
|
||||
self.u_frank.to_string(), "http://my.site/here"
|
||||
)
|
||||
)
|
||||
|
||||
@@ -64,16 +66,18 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
|
||||
"http://my.site/here",
|
||||
(
|
||||
self.get_success(
|
||||
self.store.get_profile_avatar_url(self.u_frank.localpart)
|
||||
self.store.get_profile_avatar_url(self.u_frank.to_string())
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
# test set to None
|
||||
self.get_success(
|
||||
self.store.set_profile_avatar_url(self.u_frank.localpart, None)
|
||||
self.store.set_profile_avatar_url(self.u_frank.to_string(), None)
|
||||
)
|
||||
|
||||
self.assertIsNone(
|
||||
self.get_success(self.store.get_profile_avatar_url(self.u_frank.localpart))
|
||||
self.get_success(
|
||||
self.store.get_profile_avatar_url(self.u_frank.to_string())
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user