Compare commits

...

13 Commits

Author | SHA1 | Message | Date
Olivier Wilkinson (reivilibre) | 2ccbea4107 | Antilint | 2021-08-05 18:03:12 +01:00
Olivier Wilkinson (reivilibre) | 0736840f1e | set timeout to zero | 2021-08-05 16:39:13 +01:00
Olivier Wilkinson (reivilibre) | 4bb7bc8ffd | Fix up type checking and ENABLE THE FILE :) | 2021-08-05 16:28:10 +01:00
Olivier Wilkinson (reivilibre) | 990f3b5003 | Merge remote-tracking branch 'origin/develop' into rei/gsgfg | 2021-08-05 15:51:54 +01:00
Olivier Wilkinson (reivilibre) | bbb0473cd0 | Fix type error | 2021-08-05 15:43:39 +01:00
Olivier Wilkinson (reivilibre) | dcb6fc5023 | Predicate it on TYPE_CHECKING (not that it improves things) | 2021-08-05 15:35:25 +01:00
Olivier Wilkinson (reivilibre) | 215019cd66 | Fix-ups | 2021-08-05 15:29:31 +01:00
Olivier Wilkinson (reivilibre) | 2f7eeefa4b | antilint | 2021-08-04 15:07:19 +01:00
Olivier Wilkinson (reivilibre) | 5fa9110c24 | Make StateFilter frozen | 2021-08-04 15:06:06 +01:00
Olivier Wilkinson (reivilibre) | b09de10dff | Remove _get_state_groups_from_groups_txn | 2021-08-04 14:56:30 +01:00
Olivier Wilkinson (reivilibre) | ae9d273534 | Newsfile | 2021-08-02 16:56:29 +01:00
Olivier Wilkinson (reivilibre) | cefcab7734 | Use a ResponseCache and make keys hashable | 2021-08-02 16:56:28 +01:00
Olivier Wilkinson (reivilibre) | 507cafc2c3 | Make a single-group transaction function | 2021-08-02 16:56:28 +01:00
6 changed files with 155 additions and 94 deletions

changelog.d/10510.misc (new file, 1 line added)
View File

@@ -0,0 +1 @@
+Make _get_state_groups_from_groups use caching (for each individual group to query).
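
In outline, the branch swaps the old chunked multi-group lookup for a per-group lookup wrapped in a ResponseCache, so concurrent requests for the same state group share a single database transaction. A condensed sketch of the pattern, using the names that appear in the StateGroupDataStore hunks further down (error handling and docstrings elided):

    # Sketch only; the real methods are shown in the diff below.
    async def _get_state_groups_from_groups(self, groups, state_filter):
        results = {}
        for group in groups:
            # One cached, single-group fetch per group: the ResponseCache
            # de-duplicates concurrent fetches that share the same key.
            results[group] = await self._state_group_from_group_cache.wrap(
                (group, state_filter),
                self.db_pool.runInteraction,
                "_get_state_groups_from_group",
                self._get_state_groups_from_group_txn,
                group,
                state_filter,
            )
        return results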

View File

@@ -88,6 +88,7 @@ files =
 tests/handlers/test_password_providers.py,
 tests/rest/client/v1/test_login.py,
 tests/rest/client/v2_alpha/test_auth.py,
+tests/storage/test_state.py,
 tests/util/test_itertools.py,
 tests/util/test_stream_change_cache.py

View File

@@ -13,6 +13,7 @@
 # limitations under the License.
 import logging
+import typing
 from typing import Optional
 
 from synapse.storage._base import SQLBaseStore
@@ -20,6 +21,9 @@ from synapse.storage.database import DatabasePool
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.state import StateFilter
 
+if typing.TYPE_CHECKING:
+    from synapse.types import StateMap
+
 logger = logging.getLogger(__name__)
@@ -72,12 +76,12 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
         return count
 
-    def _get_state_groups_from_groups_txn(
-        self, txn, groups, state_filter: Optional[StateFilter] = None
-    ):
+    def _get_state_groups_from_group_txn(
+        self, txn, group: int, state_filter: Optional[StateFilter] = None
+    ) -> "StateMap[str]":
         state_filter = state_filter or StateFilter.all()
 
-        results = {group: {} for group in groups}
+        result = {}
 
         where_clause, where_args = state_filter.make_sql_filter_clause()
@@ -116,64 +120,62 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
                 ORDER BY type, state_key, state_group DESC
             """
 
-            for group in groups:
-                args = [group]
-                args.extend(where_args)
+            args = [group]
+            args.extend(where_args)
 
-                txn.execute(sql % (where_clause,), args)
-                for row in txn:
-                    typ, state_key, event_id = row
-                    key = (typ, state_key)
-                    results[group][key] = event_id
+            txn.execute(sql % (where_clause,), args)
+            for row in txn:
+                typ, state_key, event_id = row
+                key = (typ, state_key)
+                result[key] = event_id
         else:
             max_entries_returned = state_filter.max_entries_returned()
 
             # We don't use WITH RECURSIVE on sqlite3 as there are distributions
             # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
-            for group in groups:
-                next_group = group
+            next_group = group
 
-                while next_group:
-                    # We did this before by getting the list of group ids, and
-                    # then passing that list to sqlite to get latest event for
-                    # each (type, state_key). However, that was terribly slow
-                    # without the right indices (which we can't add until
-                    # after we finish deduping state, which requires this func)
-                    args = [next_group]
-                    args.extend(where_args)
+            while next_group:
+                # We did this before by getting the list of group ids, and
+                # then passing that list to sqlite to get latest event for
+                # each (type, state_key). However, that was terribly slow
+                # without the right indices (which we can't add until
+                # after we finish deduping state, which requires this func)
+                args = [next_group]
+                args.extend(where_args)
 
-                    txn.execute(
-                        "SELECT type, state_key, event_id FROM state_groups_state"
-                        " WHERE state_group = ? " + where_clause,
-                        args,
-                    )
-                    results[group].update(
-                        ((typ, state_key), event_id)
-                        for typ, state_key, event_id in txn
-                        if (typ, state_key) not in results[group]
-                    )
+                txn.execute(
+                    "SELECT type, state_key, event_id FROM state_groups_state"
+                    " WHERE state_group = ? " + where_clause,
+                    args,
+                )
+                result.update(
+                    ((typ, state_key), event_id)
+                    for typ, state_key, event_id in txn
+                    if (typ, state_key) not in result
+                )
 
-                    # If the number of entries in the (type,state_key)->event_id dict
-                    # matches the number of (type,state_keys) types we were searching
-                    # for, then we must have found them all, so no need to go walk
-                    # further down the tree... UNLESS our types filter contained
-                    # wildcards (i.e. Nones) in which case we have to do an exhaustive
-                    # search
-                    if (
-                        max_entries_returned is not None
-                        and len(results[group]) == max_entries_returned
-                    ):
-                        break
+                # If the number of entries in the (type,state_key)->event_id dict
+                # matches the number of (type,state_keys) types we were searching
+                # for, then we must have found them all, so no need to go walk
+                # further down the tree... UNLESS our types filter contained
+                # wildcards (i.e. Nones) in which case we have to do an exhaustive
+                # search
+                if (
+                    max_entries_returned is not None
+                    and len(result) == max_entries_returned
+                ):
+                    break
 
-                    next_group = self.db_pool.simple_select_one_onecol_txn(
-                        txn,
-                        table="state_group_edges",
-                        keyvalues={"state_group": next_group},
-                        retcol="prev_state_group",
-                        allow_none=True,
-                    )
+                next_group = self.db_pool.simple_select_one_onecol_txn(
+                    txn,
+                    table="state_group_edges",
+                    keyvalues={"state_group": next_group},
+                    retcol="prev_state_group",
+                    allow_none=True,
+                )
 
-        return results
+        return result
 
 
 class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
@@ -261,14 +263,10 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
                         # otherwise read performance degrades.
                         continue
 
-                    prev_state = self._get_state_groups_from_groups_txn(
-                        txn, [prev_group]
-                    )
+                    prev_state = self._get_state_groups_from_group_txn(txn, prev_group)
                     prev_state = prev_state[prev_group]
 
-                    curr_state = self._get_state_groups_from_groups_txn(
-                        txn, [state_group]
-                    )
+                    curr_state = self._get_state_groups_from_group_txn(txn, state_group)
                     curr_state = curr_state[state_group]
 
                     if not set(prev_state.keys()) - set(curr_state.keys()):
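
The net effect of the hunks above, summarised as signatures (a sketch, not additional code in the changeset): the transaction helper is narrowed from many groups to exactly one, and its return value changes from a dict keyed by group to a single flat state map.

    # Old helper: several groups per transaction, results keyed by group.
    def _get_state_groups_from_groups_txn(
        self, txn, groups, state_filter: Optional[StateFilter] = None
    ):
        ...  # returns {group: {(event_type, state_key): event_id}}

    # New helper: one group per transaction, a flat state map back.
    def _get_state_groups_from_group_txn(
        self, txn, group: int, state_filter: Optional[StateFilter] = None
    ) -> "StateMap[str]":
        ...  # returns {(event_type, state_key): event_id}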

View File

@@ -26,6 +26,7 @@ from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import MutableStateMap, StateMap
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.dictionary_cache import DictionaryCache
+from synapse.util.caches.response_cache import ResponseCache
 
 logger = logging.getLogger(__name__)
@@ -91,6 +92,15 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             500000,
         )
 
+        self._state_group_from_group_cache = ResponseCache(
+            self.hs.get_clock(),
+            # REVIEW: why do the other 2 have asterisks? should this one too?
+            "*stateGroupFromGroupCache*",
+            # we're only using this cache to track in-flight requests;
+            # the results are added to another cache once complete.
+            timeout_ms=0,
+        )
+
         def get_max_state_group_txn(txn: Cursor):
             txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
             return txn.fetchone()[0]
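
Setting timeout_ms=0 means the ResponseCache drops each entry as soon as its deferred resolves, so it only de-duplicates lookups that are in flight at the same time; the durable caching still happens in the existing DictionaryCaches. A rough illustration of that behaviour, assuming the per-group method added in the next hunk (the demo function itself is not part of the change):

    from twisted.internet import defer

    async def demo(store, group, state_filter):
        # Start two lookups for the same (group, state_filter) key before
        # either one has completed.
        d1 = defer.ensureDeferred(
            store._get_state_groups_from_group(group, state_filter)
        )
        d2 = defer.ensureDeferred(
            store._get_state_groups_from_group(group, state_filter)
        )
        # The second call attaches to the first call's in-flight
        # runInteraction instead of issuing another query; once the result
        # arrives, timeout_ms=0 discards the cache entry immediately, so a
        # later lookup of the same key goes back to the database.
        assert (await d1) == (await d2)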
@@ -156,19 +166,39 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
""" """
results = {} results = {}
chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] for group in groups:
for chunk in chunks: results[group] = await self._get_state_groups_from_group(
res = await self.db_pool.runInteraction( group, state_filter
"_get_state_groups_from_groups",
self._get_state_groups_from_groups_txn,
chunk,
state_filter,
) )
results.update(res)
return results return results
def _get_state_for_group_using_cache(self, cache, group, state_filter): async def _get_state_groups_from_group(
self, group: int, state_filter: StateFilter
) -> StateMap[str]:
"""Returns the state groups for a given group from the
database, filtering on types of state events.
Args:
group: state group ID to query
state_filter: The state filter used to fetch state
from the database.
Returns:
state map
"""
return await self._state_group_from_group_cache.wrap(
(group, state_filter),
self.db_pool.runInteraction,
"_get_state_groups_from_group",
self._get_state_groups_from_group_txn,
group,
state_filter,
)
def _get_state_for_group_using_cache(
self, cache: DictionaryCache, group: int, state_filter: StateFilter
) -> Tuple[StateMap, bool]:
"""Checks if group is in cache. See `_get_state_for_groups` """Checks if group is in cache. See `_get_state_for_groups`
Args: Args:
@@ -546,7 +576,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             # groups to non delta versions.
             for sg in remaining_state_groups:
                 logger.info("[purge] de-delta-ing remaining state group %s", sg)
-                curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
+                curr_state = self._get_state_groups_from_group_txn(txn, sg)
                 curr_state = curr_state[sg]
 
                 self.db_pool.simple_delete_txn(
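
From a caller's point of view the public behaviour is unchanged: _get_state_groups_from_groups still takes a list of state group IDs plus a StateFilter and returns a dict keyed by group; only the strategy underneath moves from 100-group chunks to cached single-group lookups. A hypothetical caller, for illustration only:

    from synapse.api.constants import EventTypes
    from synapse.storage.state import StateFilter

    async def fetch_membership_state(store, groups):
        # Fetch all membership events for the given state groups.
        state_filter = StateFilter.from_types([(EventTypes.Member, None)])
        state_by_group = await store._get_state_groups_from_groups(groups, state_filter)
        # state_by_group maps each group ID to {(event_type, state_key): event_id}.
        return state_by_group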

View File

@@ -25,12 +25,15 @@ from typing import (
 )
 
 import attr
+from frozendict import frozendict
 
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
 from synapse.types import MutableStateMap, StateMap
 
 if TYPE_CHECKING:
+    from typing import FrozenSet  # noqa: used within quoted type hint; flake8 sad
     from synapse.server import HomeServer
     from synapse.storage.databases import Databases
@@ -40,7 +43,7 @@ logger = logging.getLogger(__name__)
 T = TypeVar("T")
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, frozen=True)
 class StateFilter:
     """A filter used when querying for state.
@@ -53,14 +56,20 @@ class StateFilter:
         appear in `types`.
     """
 
-    types = attr.ib(type=Dict[str, Optional[Set[str]]])
+    types = attr.ib(type="frozendict[str, Optional[FrozenSet[str]]]")
     include_others = attr.ib(default=False, type=bool)
 
     def __attrs_post_init__(self):
         # If `include_others` is set we canonicalise the filter by removing
         # wildcards from the types dictionary
         if self.include_others:
-            self.types = {k: v for k, v in self.types.items() if v is not None}
+            # REVIEW: yucky. any better way?
+            # Work around this class being frozen.
+            object.__setattr__(
+                self,
+                "types",
+                frozendict({k: v for k, v in self.types.items() if v is not None}),
+            )
 
     @staticmethod
     def all() -> "StateFilter":
@@ -69,7 +78,7 @@ class StateFilter:
         Returns:
             The new state filter.
         """
-        return StateFilter(types={}, include_others=True)
+        return StateFilter(types=frozendict(), include_others=True)
 
     @staticmethod
     def none() -> "StateFilter":
@@ -78,7 +87,7 @@ class StateFilter:
         Returns:
             The new state filter.
         """
-        return StateFilter(types={}, include_others=False)
+        return StateFilter(types=frozendict(), include_others=False)
 
     @staticmethod
     def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
@@ -103,7 +112,12 @@ class StateFilter:
                 type_dict.setdefault(typ, set()).add(s)  # type: ignore
 
-        return StateFilter(types=type_dict)
+        return StateFilter(
+            types=frozendict(
+                (k, frozenset(v) if v is not None else None)
+                for k, v in type_dict.items()
+            )
+        )
 
     @staticmethod
     def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
@@ -116,7 +130,10 @@ class StateFilter:
         Returns:
             The new state filter
         """
-        return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
+        return StateFilter(
+            types=frozendict({EventTypes.Member: frozenset(members)}),
+            include_others=True,
+        )
 
     def return_expanded(self) -> "StateFilter":
         """Creates a new StateFilter where type wild cards have been removed
@@ -173,7 +190,7 @@ class StateFilter:
             # We want to return all non-members, but only particular
             # memberships
             return StateFilter(
-                types={EventTypes.Member: self.types[EventTypes.Member]},
+                types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
                 include_others=True,
             )
@@ -324,14 +341,16 @@ class StateFilter:
             if state_keys is None:
                 member_filter = StateFilter.all()
             else:
-                member_filter = StateFilter({EventTypes.Member: state_keys})
+                member_filter = StateFilter(frozendict({EventTypes.Member: state_keys}))
         elif self.include_others:
             member_filter = StateFilter.all()
         else:
             member_filter = StateFilter.none()
 
         non_member_filter = StateFilter(
-            types={k: v for k, v in self.types.items() if k != EventTypes.Member},
+            types=frozendict(
+                {k: v for k, v in self.types.items() if k != EventTypes.Member}
+            ),
             include_others=self.include_others,
         )
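
StateFilter becomes a frozen attrs class backed by frozendict/frozenset so that it can sit inside the (group, state_filter) keys used by the ResponseCache earlier in this changeset. A small sketch of what that buys (values are illustrative, not taken from the tests):

    from frozendict import frozendict

    from synapse.api.constants import EventTypes
    from synapse.storage.state import StateFilter

    f1 = StateFilter(
        types=frozendict({EventTypes.Member: frozenset({"@alice:test"})}),
        include_others=True,
    )
    f2 = StateFilter(
        types=frozendict({EventTypes.Member: frozenset({"@alice:test"})}),
        include_others=True,
    )
    # attrs generates value-based __eq__ and __hash__ for frozen classes, and
    # frozendict/frozenset are themselves hashable, so equal filters hash alike
    # and land on the same in-flight cache entry; the old dict/set-backed
    # filter was unhashable.
    assert f1 == f2
    assert hash((1234, f1)) == hash((1234, f2))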

View File

@@ -14,6 +14,8 @@
 import logging
 
+from frozendict import frozendict
+
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.storage.state import StateFilter
@@ -183,7 +185,9 @@ class StateStoreTestCase(HomeserverTestCase):
             self.storage.state.get_state_for_event(
                 e5.event_id,
                 state_filter=StateFilter(
-                    types={EventTypes.Member: {self.u_alice.to_string()}},
+                    types=frozendict(
+                        {EventTypes.Member: frozenset({self.u_alice.to_string()})}
+                    ),
                     include_others=True,
                 ),
             )
@@ -203,7 +207,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.storage.state.get_state_for_event(
                 e5.event_id,
                 state_filter=StateFilter(
-                    types={EventTypes.Member: set()}, include_others=True
+                    types=frozendict({EventTypes.Member: frozenset()}),
+                    include_others=True,
                 ),
             )
         )
@@ -228,7 +233,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: set()}, include_others=True
+                types=frozendict({EventTypes.Member: frozenset()}), include_others=True
             ),
         )
@@ -245,7 +250,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: set()}, include_others=True
+                types=frozendict({EventTypes.Member: frozenset()}), include_others=True
             ),
         )
@@ -258,7 +263,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: None}, include_others=True
+                types=frozendict({EventTypes.Member: None}), include_others=True
             ),
         )
@@ -275,7 +280,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: None}, include_others=True
+                types=frozendict({EventTypes.Member: None}), include_others=True
             ),
         )
@@ -295,7 +300,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=True
+                types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+                include_others=True,
             ),
         )
@@ -312,7 +318,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=True
+                types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+                include_others=True,
             ),
         )
@@ -325,7 +332,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=False
+                types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+                include_others=False,
             ),
         )
@@ -375,7 +383,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: set()}, include_others=True
+                types=frozendict({EventTypes.Member: frozenset()}), include_others=True
             ),
         )
@@ -387,7 +395,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: set()}, include_others=True
+                types=frozendict({EventTypes.Member: frozenset()}), include_others=True
             ),
         )
@@ -400,7 +408,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: None}, include_others=True
+                types=frozendict({EventTypes.Member: None}), include_others=True
             ),
         )
@@ -411,7 +419,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: None}, include_others=True
+                types=frozendict({EventTypes.Member: None}), include_others=True
             ),
         )
@@ -430,7 +438,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=True
+                types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+                include_others=True,
             ),
         )
@@ -441,7 +450,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=True
+                types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+                include_others=True,
             ),
         )
@@ -454,7 +464,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=False
+                types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+                include_others=False,
             ),
         )
@@ -465,7 +476,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=False
+                types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+                include_others=False,
             ),
         )