Compare commits

...

10 Commits

Author SHA1 Message Date
Erik Johnston
546c2a8f74 Fix bound_stream_token 2024-07-12 13:59:26 +01:00
Erik Johnston
224a739085 Fix sliding sync 2024-07-12 13:58:01 +01:00
Erik Johnston
2a49675813 Add test 2024-07-12 13:46:07 +01:00
Erik Johnston
0ccc29591d Review comments 2024-07-12 13:43:25 +01:00
Erik Johnston
4bf6b069ad Apply suggestions from code review
Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
2024-07-12 13:30:50 +01:00
Erik Johnston
940b644ca9 Test both types of token 2024-07-12 11:46:39 +01:00
Erik Johnston
31e6508626 Handle existing dodgy tokens 2024-07-12 11:40:08 +01:00
Erik Johnston
6d86303c39 Newsfile 2024-07-12 10:57:01 +01:00
Erik Johnston
d2c8d4817d Add assertion to MultiWriterTokens 2024-07-12 10:57:01 +01:00
Erik Johnston
bf0911065e Fix bug where sync could get stuck when using workers
This is because we serialized the token wrong if the instance map
contained entries from before the minimum token.
2024-07-12 09:52:43 +01:00
4 changed files with 138 additions and 10 deletions

1
changelog.d/17438.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix rare bug where `/sync` would break for a user when using workers with multiple stream writers.

View File

@@ -640,10 +640,17 @@ class SlidingSyncHandler:
instance_to_max_stream_ordering_map[instance_name] = stream_ordering
# Then assemble the `RoomStreamToken`
min_stream_pos = min(instance_to_max_stream_ordering_map.values())
membership_snapshot_token = RoomStreamToken(
# Minimum position in the `instance_map`
stream=min(instance_to_max_stream_ordering_map.values()),
instance_map=immutabledict(instance_to_max_stream_ordering_map),
stream=min_stream_pos,
instance_map=immutabledict(
{
instance_name: stream_pos
for instance_name, stream_pos in instance_to_max_stream_ordering_map.items()
if stream_pos > min_stream_pos
}
),
)
# Since we fetched the users room list at some point in time after the from/to

View File

@@ -20,6 +20,7 @@
#
#
import abc
import logging
import re
import string
from enum import Enum
@@ -74,6 +75,9 @@ if TYPE_CHECKING:
from synapse.storage.databases.main import DataStore, PurgeEventsStore
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
logger = logging.getLogger(__name__)
# Define a state map type from type/state_key to T (usually an event ID or
# event)
T = TypeVar("T")
@@ -454,6 +458,8 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
represented by a default `stream` attribute and a map of instance name to
stream position of any writers that are ahead of the default stream
position.
The values in `instance_map` must be greater than the `stream` attribute.
"""
stream: int = attr.ib(validator=attr.validators.instance_of(int), kw_only=True)
@@ -468,6 +474,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
kw_only=True,
)
def __attrs_post_init__(self) -> None:
# Enforce that all instances have a value greater than the min stream
# position.
for i, v in self.instance_map.items():
if v <= self.stream:
raise ValueError(
f"'instance_map' includes a stream position before the main 'stream' attribute. Instance: {i}"
)
@classmethod
@abc.abstractmethod
async def parse(cls, store: "DataStore", string: str) -> "Self":
@@ -494,6 +509,9 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
for instance in set(self.instance_map).union(other.instance_map)
}
# Filter out any redundant entries.
instance_map = {i: s for i, s in instance_map.items() if s > max_stream}
return attr.evolve(
self, stream=max_stream, instance_map=immutabledict(instance_map)
)
@@ -539,10 +557,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
def bound_stream_token(self, max_stream: int) -> "Self":
"""Bound the stream positions to a maximum value"""
min_pos = min(self.stream, max_stream)
return type(self)(
stream=min(self.stream, max_stream),
stream=min_pos,
instance_map=immutabledict(
{k: min(s, max_stream) for k, s in self.instance_map.items()}
{
k: min(s, max_stream)
for k, s in self.instance_map.items()
if min(s, max_stream) > min_pos
}
),
)
@@ -637,6 +660,8 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
"Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'."
)
super().__attrs_post_init__()
@classmethod
async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
try:
@@ -651,6 +676,11 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
instance_map = {}
for part in parts[1:]:
if not part:
# Handle tokens of the form `m5~`, which were created by
# a bug
continue
key, value = part.split(".")
instance_id = int(key)
pos = int(value)
@@ -666,7 +696,10 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
except CancelledError:
raise
except Exception:
pass
# We log an exception here as even though this *might* be a client
# handing a bad token, its more likely that Synapse returned a bad
# token (and we really want to catch those!).
logger.exception("Failed to parse stream token: %r", string)
raise SynapseError(400, "Invalid room stream token %r" % (string,))
@classmethod
@@ -713,6 +746,8 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
return self.instance_map.get(instance_name, self.stream)
async def to_string(self, store: "DataStore") -> str:
"""See class level docstring for information about the format."""
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
elif self.instance_map:
@@ -727,8 +762,10 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
instance_id = await store.get_id_for_instance(name)
entries.append(f"{instance_id}.{pos}")
encoded_map = "~".join(entries)
return f"m{self.stream}~{encoded_map}"
if entries:
encoded_map = "~".join(entries)
return f"m{self.stream}~{encoded_map}"
return f"s{self.stream}"
else:
return "s%d" % (self.stream,)
@@ -756,6 +793,11 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
instance_map = {}
for part in parts[1:]:
if not part:
# Handle tokens of the form `m5~`, which were created by
# a bug
continue
key, value = part.split(".")
instance_id = int(key)
pos = int(value)
@@ -770,10 +812,15 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
except CancelledError:
raise
except Exception:
pass
# We log an exception here as even though this *might* be a client
# handing a bad token, its more likely that Synapse returned a bad
# token (and we really want to catch those!).
logger.exception("Failed to parse stream token: %r", string)
raise SynapseError(400, "Invalid stream token %r" % (string,))
async def to_string(self, store: "DataStore") -> str:
"""See class level docstring for information about the format."""
if self.instance_map:
entries = []
for name, pos in self.instance_map.items():
@@ -786,8 +833,10 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
instance_id = await store.get_id_for_instance(name)
entries.append(f"{instance_id}.{pos}")
encoded_map = "~".join(entries)
return f"m{self.stream}~{encoded_map}"
if entries:
encoded_map = "~".join(entries)
return f"m{self.stream}~{encoded_map}"
return str(self.stream)
else:
return str(self.stream)

View File

@@ -19,9 +19,18 @@
#
#
from typing import Type
from unittest import skipUnless
from immutabledict import immutabledict
from parameterized import parameterized_class
from synapse.api.errors import SynapseError
from synapse.types import (
AbstractMultiWriterStreamToken,
MultiWriterStreamToken,
RoomAlias,
RoomStreamToken,
UserID,
get_domain_from_id,
get_localpart_from_id,
@@ -29,6 +38,7 @@ from synapse.types import (
)
from tests import unittest
from tests.utils import USE_POSTGRES_FOR_TESTS
class IsMineIDTests(unittest.HomeserverTestCase):
@@ -127,3 +137,64 @@ class MapUsernameTestCase(unittest.TestCase):
# this should work with either a unicode or a bytes
self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast")
self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast")
@parameterized_class(
("token_type",),
[
(MultiWriterStreamToken,),
(RoomStreamToken,),
],
class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{params_dict['token_type'].__name__}",
)
class MultiWriterTokenTestCase(unittest.HomeserverTestCase):
"""Tests for the different types of multi writer tokens."""
token_type: Type[AbstractMultiWriterStreamToken]
def test_basic_token(self) -> None:
"""Test that a simple stream token can be serialized and unserialized"""
store = self.hs.get_datastores().main
token = self.token_type(stream=5)
string_token = self.get_success(token.to_string(store))
if isinstance(token, RoomStreamToken):
self.assertEqual(string_token, "s5")
else:
self.assertEqual(string_token, "5")
parsed_token = self.get_success(self.token_type.parse(store, string_token))
self.assertEqual(parsed_token, token)
@skipUnless(USE_POSTGRES_FOR_TESTS, "Requires Postgres")
def test_instance_map(self) -> None:
"""Test for stream token with instance map"""
store = self.hs.get_datastores().main
token = self.token_type(stream=5, instance_map=immutabledict({"foo": 6}))
string_token = self.get_success(token.to_string(store))
self.assertEqual(string_token, "m5~1.6")
parsed_token = self.get_success(self.token_type.parse(store, string_token))
self.assertEqual(parsed_token, token)
def test_instance_map_assertion(self) -> None:
"""Test that we assert values in the instance map are greater than the
min stream position"""
with self.assertRaises(ValueError):
self.token_type(stream=5, instance_map=immutabledict({"foo": 4}))
with self.assertRaises(ValueError):
self.token_type(stream=5, instance_map=immutabledict({"foo": 5}))
def test_parse_bad_token(self) -> None:
"""Test that we can parse tokens produced by a bug in Synapse of the
form `m5~`"""
store = self.hs.get_datastores().main
parsed_token = self.get_success(self.token_type.parse(store, "m5~"))
self.assertEqual(parsed_token, self.token_type(stream=5))