Compare commits

...

14 Commits

Author SHA1 Message Date
David Robertson
1ad179183f Tweak test 2023-03-10 17:54:22 +00:00
David Robertson
bc619d9565 New test 2023-03-10 16:34:37 +00:00
David Robertson
2bf5878010 change_membership: expose json_body 2023-03-10 16:34:19 +00:00
David Robertson
722ff49567 Changelog 2023-03-10 16:34:19 +00:00
David Robertson
6daadb9bd2 Avoid state lookups for events we'll ignore 2023-03-10 16:34:19 +00:00
David Robertson
3ad2f5f426 Test 2023-03-10 16:34:18 +00:00
David Robertson
c246e64701 Rename param 2023-03-10 13:43:40 +00:00
David Robertson
30965867ac Changelog 2023-03-10 12:10:55 +00:00
David Robertson
85a98b1023 Add explicit option for partial state rooms 2023-03-10 12:08:17 +00:00
David Robertson
909eecb348 Require explicit boolean options from callers 2023-03-10 12:00:48 +00:00
David Robertson
ab4eea570f Track a set of strings, not EventBases 2023-03-10 11:56:41 +00:00
David Robertson
fa5ca2edea Separate decision from action 2023-03-10 11:55:40 +00:00
David Robertson
21d1fc8cf2 Flip logic and provide better name 2023-03-10 11:55:40 +00:00
David Robertson
cda97ccdb9 Tweak docstring and type hint 2023-03-10 11:55:39 +00:00
9 changed files with 327 additions and 35 deletions

1
changelog.d/15240.misc Normal file
View File

@@ -0,0 +1 @@
Refactor `filter_events_for_server`.

1
changelog.d/15241.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix a rare bug introduced in Synapse 1.73 where events could remain unsent to other homeservers after a faster-join to a room.

View File

@@ -547,6 +547,8 @@ class PerDestinationQueue:
self._server_name,
new_pdus,
redact=False,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
)
# If we've filtered out all the extremities, fall back to

View File

@@ -392,7 +392,7 @@ class FederationHandler:
get_prev_content=False,
)
# We set `check_history_visibility_only` as we might otherwise get false
# We unset `filter_out_erased_senders` as we might otherwise get false
# positives from users having been erased.
filtered_extremities = await filter_events_for_server(
self._storage_controllers,
@@ -400,7 +400,8 @@ class FederationHandler:
self.server_name,
events_to_check,
redact=False,
check_history_visibility_only=True,
filter_out_erased_senders=False,
filter_out_remote_partial_state_events=False,
)
if filtered_extremities:
extremities_to_request.append(bp.event_id)
@@ -1331,7 +1332,13 @@ class FederationHandler:
)
events = await filter_events_for_server(
self._storage_controllers, origin, self.server_name, events
self._storage_controllers,
origin,
self.server_name,
events,
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
)
return events
@@ -1362,7 +1369,13 @@ class FederationHandler:
await self._event_auth_handler.assert_host_in_room(event.room_id, origin)
events = await filter_events_for_server(
self._storage_controllers, origin, self.server_name, [event]
self._storage_controllers,
origin,
self.server_name,
[event],
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
)
event = events[0]
return event
@@ -1390,7 +1403,13 @@ class FederationHandler:
)
missing_events = await filter_events_for_server(
self._storage_controllers, origin, self.server_name, missing_events
self._storage_controllers,
origin,
self.server_name,
missing_events,
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
)
return missing_events

View File

@@ -14,7 +14,17 @@
# limitations under the License.
import logging
from enum import Enum, auto
from typing import Collection, Dict, FrozenSet, List, Optional, Tuple
from typing import (
Collection,
Dict,
FrozenSet,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
)
import attr
from typing_extensions import Final
@@ -565,29 +575,43 @@ async def filter_events_for_server(
storage: StorageControllers,
target_server_name: str,
local_server_name: str,
events: List[EventBase],
redact: bool = True,
check_history_visibility_only: bool = False,
events: Sequence[EventBase],
*,
redact: bool,
filter_out_erased_senders: bool,
filter_out_remote_partial_state_events: bool,
) -> List[EventBase]:
"""Filter a list of events based on whether given server is allowed to
"""Filter a list of events based on whether the target server is allowed to
see them.
For a fully stated room, the target server is allowed to see an event E if:
- the state at E has world readable or shared history vis, OR
- the state at E says that the target server is in the room.
For a partially stated room, the target server is allowed to see E if:
- E was created by this homeserver, AND:
- the partial state at E has world readable or shared history vis, OR
- the partial state at E says that the target server is in the room.
TODO: state before or state after?
Args:
storage
server_name
target_server_name
local_server_name
events
redact: Whether to return a redacted version of the event, or
to filter them out entirely.
check_history_visibility_only: Whether to only check the
history visibility, rather than things like if the sender has been
redact: Controls what to do with events which have been filtered out.
If True, include their redacted forms; if False, omit them entirely.
filter_out_erased_senders: If true, also filter out events whose sender has been
erased. This is used e.g. during pagination to decide whether to
backfill or not.
filter_out_remote_partial_state_events: If True, also filter out events in
partial state rooms created by other homeservers.
Returns
The filtered events.
"""
def is_sender_erased(event: EventBase, erased_senders: Dict[str, bool]) -> bool:
def is_sender_erased(event: EventBase, erased_senders: Mapping[str, bool]) -> bool:
if erased_senders and erased_senders[event.sender]:
logger.info("Sender of %s has been erased, redacting", event.event_id)
return True
@@ -616,7 +640,7 @@ async def filter_events_for_server(
# server has no users in the room: redact
return False
if not check_history_visibility_only:
if filter_out_erased_senders:
erased_senders = await storage.main.are_users_erased(e.sender for e in events)
else:
# We don't want to check whether users are erased, which is equivalent
@@ -631,44 +655,52 @@ async def filter_events_for_server(
# otherwise a room could be fully joined after we retrieve those, which would then bypass
# this check but would base the filtering on an outdated view of the membership events.
partial_state_invisible_events = set()
if not check_history_visibility_only:
partial_state_invisible_event_ids: Set[str] = set()
if filter_out_remote_partial_state_events:
maybe_visible_events: List[EventBase] = []
for e in events:
sender_domain = get_domain_from_id(e.sender)
if (
sender_domain != local_server_name
and await storage.main.is_partial_state_room(e.room_id)
):
partial_state_invisible_events.add(e)
partial_state_invisible_event_ids.add(e.event_id)
else:
maybe_visible_events.append(e)
else:
maybe_visible_events = list(events)
# Let's check to see if all the events have a history visibility
# of "shared" or "world_readable". If that's the case then we don't
# need to check membership (as we know the server is in the room).
event_to_history_vis = await _event_to_history_vis(storage, events)
event_to_history_vis = await _event_to_history_vis(storage, maybe_visible_events)
# for any with restricted vis, we also need the memberships
event_to_memberships = await _event_to_memberships(
storage,
[
e
for e in events
for e in maybe_visible_events
if event_to_history_vis[e.event_id]
not in (HistoryVisibility.SHARED, HistoryVisibility.WORLD_READABLE)
],
target_server_name,
)
to_return = []
for e in events:
def include_event_in_output(e: EventBase) -> bool:
if e.event_id in partial_state_invisible_event_ids:
return False
erased = is_sender_erased(e, erased_senders)
visible = check_event_is_visible(
event_to_history_vis[e.event_id], event_to_memberships.get(e.event_id, {})
)
if e in partial_state_invisible_events:
visible = False
return visible and not erased
if visible and not erased:
to_return = []
for e in events:
if include_event_in_output(e):
to_return.append(e)
elif redact:
to_return.append(prune_event(e))

View File

@@ -1,4 +1,5 @@
from typing import Callable, List, Optional, Tuple
from typing import Callable, Collection, List, Optional, Tuple
from unittest import mock
from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
@@ -500,3 +501,177 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
self.assertEqual(len(sent_pdus), 1)
self.assertEqual(sent_pdus[0].event_id, event_2.event_id)
self.assertFalse(per_dest_queue._catching_up)
def test_catch_up_is_not_blocked_by_remote_event_in_partial_state_room(
    self,
) -> None:
    """Catch-up must skip a remote event with partial state instead of waiting on it.

    Detects (part of?) https://github.com/matrix-org/synapse/issues/15220.
    """
    # ARRANGE
    # A local user u1 shares a room with two remote users, @u2:host2 and
    # @u3:other. History visibility is "joined". u1 sends two messages, then
    # a membership event for u3 is injected. Catch-up for host2 starts from
    # the point where only the first message had been delivered.
    queue, pdus_sent = self.make_fake_destination_queue()

    password = "you the one"
    self.register_user("u1", password)
    u1_token = self.login("u1", password)
    room = self.helper.create_room_as("u1", tok=u1_token)
    self.helper.send_state(
        room_id=room,
        event_type="m.room.history_visibility",
        body={"history_visibility": "joined"},
        tok=u1_token,
    )
    for remote_user in ("@u2:host2", "@u3:other"):
        self.get_success(
            event_injection.inject_member_event(self.hs, room, remote_user, "join")
        )

    # Two local messages ...
    first_msg_id = self.helper.send(room, "hello", tok=u1_token)["event_id"]
    second_msg_id = self.helper.send(room, "world", tok=u1_token)["event_id"]

    # ... followed by a second "join" for u3 (standing in for a displayname
    # change), which is the remote event under test.
    remote_member_event = self.get_success(
        event_injection.inject_member_event(self.hs, room, "@u3:other", "join")
    )
    remote_event_id = remote_member_event.event_id

    # destination_rooms should already be populated; record that host2 has
    # successfully received everything up to and including the first message.
    main_store = self.hs.get_datastores().main
    first_msg = self.get_success(main_store.get_event(first_msg_id))
    last_sent_ordering = first_msg.internal_metadata.stream_ordering
    assert last_sent_ordering is not None
    self.get_success(
        main_store.set_destination_last_successful_stream_ordering(
            "host2", last_sent_ordering
        )
    )

    # Fetched now so its stream ordering can be compared with the queue's
    # last_successful_stream_ordering after catch-up.
    second_msg = self.get_success(main_store.get_event(second_msg_id))

    # Pretend the remote membership event only has partial state.
    self.get_success(
        event_injection.mark_event_as_partial_state(self.hs, remote_event_id, room)
    )

    # Any attempt to wait for full state of that event fails the test.
    async def mock_await_full_state(event_ids: Collection[str]) -> None:
        if remote_event_id in event_ids:
            raise AssertionError("Tried to await full state for event_id_3")

    # ACT
    tracker = self.hs.get_storage_controllers().state._partial_state_events_tracker
    with mock.patch.object(tracker, "await_full_state", mock_await_full_state):
        self.get_success(queue._catch_up_transmission_loop())

    # ASSERT
    # - the remote event was not sent: it is not ours and the room is
    #   partial-stated;
    # - the second message was sent instead, being the most recent event in
    #   the room that we tried to send to host2;
    # - catch-up finished.
    self.assertEqual(len(pdus_sent), 1)
    self.assertEqual(pdus_sent[0].event_id, second_msg_id)
    self.assertFalse(queue._catching_up)
    self.assertEqual(
        queue._last_successful_stream_ordering,
        second_msg.internal_metadata.stream_ordering,
    )
def test_catch_up_is_not_blocked_by_local_event_in_partial_state_room(
    self,
) -> None:
    """Catch-up must not wait for full state on events in a partial-state room.

    Detects (part of?) https://github.com/matrix-org/synapse/issues/15220.
    """
    # ARRANGE
    # A local user u1 creates a room with history visibility "joined", and
    # @u3:other joins it. The following events are then created in order:
    #   e1: message from u1
    #   e2: remote user @u2:host2 joins
    #   e3: message from u1
    #   e4: message from u1
    #   e5: membership event sent by @u1:test with membership "invite",
    #       targeting @u3:host2
    # Catch-up for host2 starts after e3 was the last event delivered.
    #
    # NOTE(review): this test was originally described as "e5: u1 kicks user
    # u2:host2", which does not match the invite of @u3:host2 injected below.
    # Confirm which of the two is intended.
    queue, pdus_sent = self.make_fake_destination_queue()

    password = "you the one"
    self.register_user("u1", password)
    u1_token = self.login("u1", password)
    room = self.helper.create_room_as("u1", tok=u1_token)
    self.helper.send_state(
        room_id=room,
        event_type="m.room.history_visibility",
        body={"history_visibility": "joined"},
        tok=u1_token,
    )
    self.get_success(
        event_injection.inject_member_event(self.hs, room, "@u3:other", "join")
    )

    # Create e1..e5.
    e1_id = self.helper.send(room, "Hello", tok=u1_token)["event_id"]
    u2_join = self.get_success(
        event_injection.inject_member_event(self.hs, room, "@u2:host2", "join")
    )
    e2_id = u2_join.event_id
    e3_id = self.helper.send(room, "Nicholas,", tok=u1_token)["event_id"]
    e4_id = self.helper.send(room, "how's the hand?", tok=u1_token)["event_id"]
    final_member_event = self.get_success(
        event_injection.inject_member_event(
            self.hs,
            room,
            sender="@u1:test",
            target="@u3:host2",
            membership="invite",
        )
    )
    e5_id = final_member_event.event_id

    # Pretend every one of the numbered events only has partial state.
    numbered_events_ids = (e1_id, e2_id, e3_id, e4_id, e5_id)
    for numbered_id in numbered_events_ids:
        self.get_success(
            event_injection.mark_event_as_partial_state(self.hs, numbered_id, room)
        )

    # Any attempt to wait for full state of one of those events fails the test.
    async def mock_await_full_state(event_ids: Collection[str]) -> None:
        for index, numbered_id in enumerate(numbered_events_ids, start=1):
            if numbered_id in event_ids:
                raise AssertionError(
                    f"Tried to await full state for event_id_{index}"
                )

    # Record that host2 has already been sent everything up to e3.
    main_store = self.hs.get_datastores().main
    e3 = self.get_success(main_store.get_event(e3_id))
    last_sent_ordering = e3.internal_metadata.stream_ordering
    assert last_sent_ordering is not None
    self.get_success(
        main_store.set_destination_last_successful_stream_ordering(
            "host2",
            last_sent_ordering,
        )
    )

    # ACT
    tracker = self.hs.get_storage_controllers().state._partial_state_events_tracker
    with mock.patch.object(tracker, "await_full_state", mock_await_full_state):
        self.get_success(queue._catch_up_transmission_loop())

    # ASSERT
    # TODO summarise this
    # Exactly one PDU (e5) is sent and catch-up finishes.
    self.assertEqual(len(pdus_sent), 1)
    self.assertEqual(pdus_sent[0].event_id, e5_id)
    self.assertFalse(queue._catching_up)

View File

@@ -279,7 +279,7 @@ class RestHelper:
expect_code: int = HTTPStatus.OK,
expect_errcode: Optional[str] = None,
expect_additional_fields: Optional[dict] = None,
) -> None:
) -> JsonDict:
"""
Send a membership state event into a room.
@@ -353,6 +353,7 @@ class RestHelper:
)
self.auth_user_id = temp_id
return channel.json_body
def send(
self,

View File

@@ -102,3 +102,34 @@ async def create_event(
context = await unpersisted_context.persist(event)
return event, context
async def mark_event_as_partial_state(
    hs: synapse.server.HomeServer,
    event_id: str,
    room_id: str,
) -> None:
    """
    (Falsely) flag an event as having partial state, for use in tests.

    Naughty, but occasionally useful when checking that partial state doesn't
    block something from happening.

    The room is recorded in `partial_state_rooms` if it is not already there,
    and a row for the event is then inserted into `partial_state_events`.

    If the event already has partial state, the second insert will fail
    (event_id is unique in that table).

    Args:
        hs: the homeserver whose main datastore is written to.
        event_id: the event to mark as partial-stated.
        room_id: the room containing the event.
    """
    db_pool = hs.get_datastores().main.db_pool

    # Upsert with no update values: creates the room row when it is missing
    # and leaves an existing row as it is.
    await db_pool.simple_upsert(
        table="partial_state_rooms",
        keyvalues={"room_id": room_id},
        values={},
        insertion_values={"room_id": room_id},
    )

    partial_state_event_row = {
        "room_id": room_id,
        "event_id": event_id,
    }
    await db_pool.simple_insert(
        table="partial_state_events",
        values=partial_state_event_row,
    )

View File

@@ -63,7 +63,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
filtered = self.get_success(
filter_events_for_server(
self._storage_controllers, "test_server", "hs", events_to_filter
self._storage_controllers,
"test_server",
"hs",
events_to_filter,
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
)
)
@@ -85,7 +91,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.assertEqual(
self.get_success(
filter_events_for_server(
self._storage_controllers, "remote_hs", "hs", [outlier]
self._storage_controllers,
"remote_hs",
"hs",
[outlier],
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
)
),
[outlier],
@@ -96,7 +108,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
filtered = self.get_success(
filter_events_for_server(
self._storage_controllers, "remote_hs", "local_hs", [outlier, evt]
self._storage_controllers,
"remote_hs",
"local_hs",
[outlier, evt],
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
)
)
self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
@@ -108,7 +126,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# be redacted)
filtered = self.get_success(
filter_events_for_server(
self._storage_controllers, "other_server", "local_hs", [outlier, evt]
self._storage_controllers,
"other_server",
"local_hs",
[outlier, evt],
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
)
)
self.assertEqual(filtered[0], outlier)
@@ -143,7 +167,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# ... and the filtering happens.
filtered = self.get_success(
filter_events_for_server(
self._storage_controllers, "test_server", "local_hs", events_to_filter
self._storage_controllers,
"test_server",
"local_hs",
events_to_filter,
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
)
)