Compare commits

...

13 Commits

Author SHA1 Message Date
David Robertson
36dca8bba5 Canonicaljson is supposed to have types now 2022-10-02 23:36:01 +01:00
David Robertson
a952172389 Also disallow-untyped-defs lol 2022-10-02 23:15:45 +01:00
David Robertson
85febbd3ac Fix a surprisingly tricky mypy error
Previous commit makes `QueryList`'s values be `Optional[str]` instead of `str`.

Before the change
- `device_id` is deduced to be an `Optional[str]` as it comes from
  iterating over `query_list`.
- when we use `device_id` to mark deleted devices as holding `None` in
  `result`, mypy complains that we are using an `Optional[str]` to
  look up something in a Dict whose keys are `str`.

Fix this in two steps.

1. Avoid name reuse.
2. Don't store `None` in the initial version of `deleted_devices`.
2022-10-02 00:35:06 +01:00
David Robertson
a97c284e62 Fix incorrect annotations 2022-10-02 00:32:31 +01:00
David Robertson
684f336c43 Suppress false positives from mypy 2022-10-01 23:57:19 +01:00
David Robertson
f8fd5ddefe Fix incorrect annotation in the module API 2022-10-01 23:50:03 +01:00
David Robertson
d75da0e392 Fix errors related to ambiguous db_autocommit 2022-10-01 23:50:03 +01:00
David Robertson
9c34f6eaee Annotate runInteraction 2022-10-01 22:23:53 +01:00
David Robertson
9a35283393 Annotate advanced function 2022-10-01 22:22:42 +01:00
David Robertson
e67cd89e7b Reorder args so *args, **kwargs comes at the end 2022-10-01 22:20:14 +01:00
David Robertson
c7ee636762 use _advanced function where isolation_level and db_autocommit is in use 2022-10-01 21:49:12 +01:00
David Robertson
a71ec5e67c Separate runInteraction into simple and advanced 2022-10-01 21:42:48 +01:00
David Robertson
a591a3f778 inner_func takes a twisted Connection object 2022-09-30 22:51:22 +01:00
15 changed files with 135 additions and 85 deletions

View File

@@ -97,9 +97,6 @@ disallow_untyped_defs = False
[mypy-synapse.server]
disallow_untyped_defs = False
[mypy-synapse.storage.database]
disallow_untyped_defs = False
[mypy-tests.*]
disallow_untyped_defs = False
@@ -139,9 +136,6 @@ disallow_untyped_defs = True
[mypy-authlib.*]
ignore_missing_imports = True
[mypy-canonicaljson]
ignore_missing_imports = True
[mypy-ijson.*]
ignore_missing_imports = True

View File

@@ -30,7 +30,7 @@ from typing import (
import attr
import jinja2
from typing_extensions import ParamSpec
from typing_extensions import Concatenate, ParamSpec
from twisted.internet import defer
from twisted.web.resource import Resource
@@ -813,7 +813,7 @@ class ModuleApi:
def run_db_interaction(
self,
desc: str,
func: Callable[P, T],
func: Callable[Concatenate[LoggingTransaction, P], T],
*args: P.args,
**kwargs: P.kwargs,
) -> "defer.Deferred[T]":
@@ -831,9 +831,8 @@ class ModuleApi:
Returns:
Deferred[object]: result of func
"""
# type-ignore: See https://github.com/python/mypy/issues/8862
return defer.ensureDeferred(
self._store.db_pool.runInteraction(desc, func, *args, **kwargs) # type: ignore[arg-type]
self._store.db_pool.runInteraction(desc, func, *args, **kwargs)
)
def register_cached_function(self, cached_func: CachedFunction) -> None:

View File

@@ -34,6 +34,7 @@ from typing import (
Tuple,
Type,
TypeVar,
Union,
cast,
overload,
)
@@ -801,11 +802,22 @@ class DatabasePool:
async def runInteraction(
self,
desc: str,
func: Callable[..., R],
*args: Any,
db_autocommit: bool = False,
isolation_level: Optional[int] = None,
**kwargs: Any,
func: Callable[Concatenate[LoggingTransaction, P], R],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
return await self.runInteraction_advanced(
desc, False, None, func, *args, **kwargs
)
async def runInteraction_advanced(
self,
desc: str,
db_autocommit: bool,
isolation_level: Optional[int],
func: Callable[Concatenate[LoggingTransaction, P], R],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
"""Starts a transaction on the database and runs a given function
@@ -916,7 +928,7 @@ class DatabasePool:
start_time = monotonic_time()
def inner_func(conn, *args, **kwargs):
def inner_func(conn: adbapi.Connection, *args: P.args, **kwargs: P.kwargs) -> R:
# We shouldn't be in a transaction. If we are then something
# somewhere hasn't committed after doing work. (This is likely only
# possible during startup, as `run*` will ensure changes are
@@ -1009,7 +1021,7 @@ class DatabasePool:
decoder: Optional[Callable[[Cursor], R]],
query: str,
*args: Any,
) -> R:
) -> Union[R, List[Tuple[Any, ...]]]:
"""Runs a single query for a result set.
Args:
@@ -1022,7 +1034,7 @@ class DatabasePool:
The result of decoder(results)
"""
def interaction(txn):
def interaction(txn: LoggingTransaction) -> Union[R, List[Tuple[Any, ...]]]:
txn.execute(query, args)
if decoder:
return decoder(txn)
@@ -1202,15 +1214,16 @@ class DatabasePool:
# We can autocommit if it is safe to upsert
autocommit = table not in self._unsafe_to_upsert_tables
return await self.runInteraction(
return await self.runInteraction_advanced(
desc,
autocommit,
None,
self.simple_upsert_txn,
table,
keyvalues,
values,
insertion_values,
lock=lock,
db_autocommit=autocommit,
)
except self.engine.module.IntegrityError as e:
attempts += 1
@@ -1425,8 +1438,10 @@ class DatabasePool:
# We can autocommit if it safe to upsert
autocommit = table not in self._unsafe_to_upsert_tables
await self.runInteraction(
await self.runInteraction_advanced(
desc,
autocommit,
None,
self.simple_upsert_many_txn,
table,
key_names,
@@ -1434,7 +1449,6 @@ class DatabasePool:
value_names,
value_values,
lock=lock,
db_autocommit=autocommit,
)
def simple_upsert_many_txn(
@@ -1611,14 +1625,15 @@ class DatabasePool:
statement returns no rows
desc: description of the transaction, for logging and metrics
"""
return await self.runInteraction(
return await self.runInteraction_advanced(
desc,
True,
None,
self.simple_select_one_txn,
table,
keyvalues,
retcols,
allow_none,
db_autocommit=True,
)
@overload
@@ -1662,14 +1677,19 @@ class DatabasePool:
statement returns no rows
desc: description of the transaction, for logging and metrics
"""
return await self.runInteraction(
return await self.runInteraction_advanced(
desc,
True,
None,
self.simple_select_one_onecol_txn,
table,
keyvalues,
retcol,
allow_none=allow_none,
db_autocommit=True,
# Type ignore suppresses a mypy bug:
# Argument "allow_none" to "runInteraction_advanced" of "DatabasePool"
# has incompatible type "bool"; expected "Literal[False]" [arg-type]
# I think mypy is confused by the overloads of simple_select_one_onecol_txn.
allow_none=allow_none, # type:ignore[arg-type]
)
@overload
@@ -1721,7 +1741,7 @@ class DatabasePool:
def simple_select_onecol_txn(
txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
keyvalues: Optional[Dict[str, Any]],
retcol: str,
) -> List[Any]:
sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
@@ -1753,13 +1773,14 @@ class DatabasePool:
Returns:
Results in a list
"""
return await self.runInteraction(
return await self.runInteraction_advanced(
desc,
True,
None,
self.simple_select_onecol_txn,
table,
keyvalues,
retcol,
db_autocommit=True,
)
async def simple_select_list(
@@ -1783,13 +1804,14 @@ class DatabasePool:
Returns:
A list of dictionaries.
"""
return await self.runInteraction(
return await self.runInteraction_advanced(
desc,
True,
None,
self.simple_select_list_txn,
table,
keyvalues,
retcols,
db_autocommit=True,
)
@classmethod
@@ -1853,15 +1875,16 @@ class DatabasePool:
results: List[Dict[str, Any]] = []
for chunk in batch_iter(iterable, batch_size):
rows = await self.runInteraction(
rows = await self.runInteraction_advanced(
desc,
True,
None,
self.simple_select_many_txn,
table,
column,
chunk,
keyvalues,
retcols,
db_autocommit=True,
)
results.extend(rows)
@@ -1949,7 +1972,7 @@ class DatabasePool:
key_names: Collection[str],
key_values: Collection[Iterable[Any]],
value_names: Collection[str],
value_values: Iterable[Iterable[Any]],
value_values: Collection[Iterable[Any]],
desc: str,
) -> None:
"""
@@ -2039,13 +2062,14 @@ class DatabasePool:
updatevalues: dict giving column names and values to update
desc: description of the transaction, for logging and metrics
"""
await self.runInteraction(
await self.runInteraction_advanced(
desc,
True,
None,
self.simple_update_one_txn,
table,
keyvalues,
updatevalues,
db_autocommit=True,
)
@classmethod
@@ -2104,12 +2128,13 @@ class DatabasePool:
keyvalues: dict of column names and values to select the row with
desc: description of the transaction, for logging and metrics
"""
await self.runInteraction(
await self.runInteraction_advanced(
desc,
True,
None,
self.simple_delete_one_txn,
table,
keyvalues,
db_autocommit=True,
)
@staticmethod
@@ -2149,8 +2174,8 @@ class DatabasePool:
Returns:
The number of deleted rows.
"""
return await self.runInteraction(
desc, self.simple_delete_txn, table, keyvalues, db_autocommit=True
return await self.runInteraction_advanced(
desc, True, None, self.simple_delete_txn, table, keyvalues
)
@staticmethod
@@ -2199,14 +2224,15 @@ class DatabasePool:
Returns:
Number rows deleted
"""
return await self.runInteraction(
return await self.runInteraction_advanced(
desc,
True,
None,
self.simple_delete_many_txn,
table,
column,
iterable,
keyvalues,
db_autocommit=True,
)
@staticmethod
@@ -2392,14 +2418,15 @@ class DatabasePool:
A list of dictionaries or None.
"""
return await self.runInteraction(
return await self.runInteraction_advanced(
desc,
True,
None,
self.simple_search_list_txn,
table,
term,
col,
retcols,
db_autocommit=True,
)
@classmethod

View File

@@ -278,7 +278,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
def _get_e2e_device_keys_txn(
self,
txn: LoggingTransaction,
query_list: Collection[Tuple[str, str]],
query_list: Collection[Tuple[str, Optional[str]]],
include_all_devices: bool = False,
include_deleted_devices: bool = False,
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
@@ -295,15 +295,19 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
include_deleted_devices = False
if include_deleted_devices:
deleted_devices = set(query_list)
deleted_devices = {
(user_id, device_id)
for user_id, device_id in query_list
if device_id is not None
}
for (user_id, device_id) in query_list:
for (queried_user_id, queried_device_id) in query_list:
query_clause = "user_id = ?"
query_params.append(user_id)
query_params.append(queried_user_id)
if device_id is not None:
if queried_device_id is not None:
query_clause += " AND device_id = ?"
query_params.append(device_id)
query_params.append(queried_device_id)
query_clauses.append(query_clause)
@@ -322,10 +326,16 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
txn.execute(sql, query_params)
result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {}
for (user_id, device_id, display_name, key_json) in txn:
fetched_user_id: str
fetched_device_id: str
display_name: Optional[str]
key_json: Optional[str]
for (fetched_user_id, fetched_device_id, display_name, key_json) in txn:
if include_deleted_devices:
deleted_devices.remove((user_id, device_id))
result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
deleted_devices.remove((fetched_user_id, fetched_device_id))
result.setdefault(fetched_user_id, {})[
fetched_device_id
] = DeviceKeyLookupResult(
display_name, db_to_json(key_json) if key_json else None
)
@@ -1082,13 +1092,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
_claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
db_autocommit = False
claim_row = await self.db_pool.runInteraction(
claim_row = await self.db_pool.runInteraction_advanced(
"claim_e2e_one_time_keys",
db_autocommit,
None,
_claim_e2e_one_time_key,
user_id,
device_id,
algorithm,
db_autocommit=db_autocommit,
)
if claim_row:
device_results = results.setdefault(user_id, {}).setdefault(

View File

@@ -1501,13 +1501,14 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
event_id: The event that failed to be fetched or processed
cause: The error message or reason that we failed to pull the event
"""
await self.db_pool.runInteraction(
await self.db_pool.runInteraction_advanced(
"record_event_failed_pull_attempt",
True, # Safe to autocommit as it's a single upsert
None,
self._record_event_failed_pull_attempt_upsert_txn,
room_id,
event_id,
cause,
db_autocommit=True, # Safe as it's a single upsert
)
def _record_event_failed_pull_attempt_upsert_txn(
@@ -1689,10 +1690,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return row[0]
return await self.db_pool.runInteraction(
return await self.db_pool.runInteraction_advanced(
"remove_received_event_from_staging",
_remove_received_event_from_staging_txn,
db_autocommit=True,
isolation_level=None,
func=_remove_received_event_from_staging_txn,
)
else:

View File

@@ -162,12 +162,11 @@ class LockStore(SQLBaseStore):
# We only acquired the lock if we inserted or updated the table.
return bool(txn.rowcount)
did_lock = await self.db_pool.runInteraction(
did_lock = await self.db_pool.runInteraction_advanced(
"try_acquire_lock",
_try_acquire_lock_txn,
# We can autocommit here as we're executing a single query, this
# will avoid serialization errors.
db_autocommit=True,
isolation_level=None,
func=_try_acquire_lock_txn,
)
if not did_lock:
return None

View File

@@ -12,7 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, cast
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Optional,
Sequence,
Tuple,
cast,
)
from synapse.api.presence import PresenceState, UserPresenceState
from synapse.replication.tcp.streams import PresenceStream
@@ -126,7 +136,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
def _update_presence_txn(
self,
txn: LoggingTransaction,
stream_orderings: List[int],
stream_orderings: Sequence[int],
presence_states: List[UserPresenceState],
) -> None:
for stream_id, state in zip(stream_orderings, presence_states):

View File

@@ -325,11 +325,12 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
# We then run the same purge a second time without this isolation level to
# purge any of those rows which were added during the first.
state_groups_to_delete = await self.db_pool.runInteraction(
state_groups_to_delete = await self.db_pool.runInteraction_advanced(
"purge_room",
self._purge_room_txn,
room_id=room_id,
False,
isolation_level=IsolationLevel.READ_COMMITTED,
func=self._purge_room_txn,
room_id=room_id,
)
state_groups_to_delete.extend(

View File

@@ -380,8 +380,8 @@ class PushRuleStore(PushRulesWorkerStore):
priority_class: int,
conditions_json: str,
actions_json: str,
before: str,
after: str,
before: Optional[str],
after: Optional[str],
) -> None:
# Lock the table since otherwise we'll have annoying races between the
# SELECT here and the UPSERT below.

View File

@@ -22,6 +22,7 @@ from typing import (
Iterator,
List,
Optional,
Sequence,
Tuple,
cast,
)
@@ -650,7 +651,9 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
# account.
pushers = list(await self.get_pushers_by_user_id(user_id))
def delete_pushers_txn(txn: LoggingTransaction, stream_ids: List[int]) -> None:
def delete_pushers_txn(
txn: LoggingTransaction, stream_ids: Sequence[int]
) -> None:
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_if_user_has_pusher, (user_id,)
)

View File

@@ -777,8 +777,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
event_ts = await self.db_pool.runInteraction(
event_ts = await self.db_pool.runInteraction_advanced(
"insert_linearized_receipt",
False,
IsolationLevel.READ_COMMITTED,
self._insert_linearized_receipt_txn,
room_id,
receipt_type,
@@ -787,10 +789,6 @@ class ReceiptsWorkerStore(SQLBaseStore):
thread_id,
data,
stream_id=stream_id,
# Read committed is actually beneficial here because we check for a receipt with
# greater stream order, and checking the very latest data at select time is better
# than the data at transaction start time.
isolation_level=IsolationLevel.READ_COMMITTED,
)
# If the receipt was older than the currently persisted one, nothing to do.

View File

@@ -420,14 +420,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
self,
txn: LoggingTransaction,
user_id: str,
membership_list: List[str],
membership_list: Collection[str],
) -> List[RoomsForUser]:
"""Get all the rooms for this *local* user where the membership for this user
matches one in the membership list.
Args:
user_id: The user ID.
membership_list: A list of synapse.api.constants.Membership
membership_list: A collection of synapse.api.constants.Membership
values which the user must be in.
Returns:

View File

@@ -221,14 +221,15 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
retry_interval: how long until next retry in ms
"""
await self.db_pool.runInteraction(
await self.db_pool.runInteraction_advanced(
"set_destination_retry_timings",
True,
None,
self._set_destination_retry_timings_native,
destination,
failure_ts,
retry_last_ts,
retry_interval,
db_autocommit=True, # Safe as it's a single upsert
)
def _set_destination_retry_timings_native(

View File

@@ -547,6 +547,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_group
if prev_group is not None:
# This assertion is for mypy's benefit and is checked above.
assert delta_ids is not None
state_group = await self.db_pool.runInteraction(
"store_state_group.insert_delta_group",
insert_delta_group_txn,

View File

@@ -782,11 +782,12 @@ class _MultiWriterCtxManager:
async def __aenter__(self) -> Union[int, List[int]]:
# It's safe to run this in autocommit mode as fetching values from a
# sequence ignores transaction semantics anyway.
self.stream_ids = await self.id_gen._db.runInteraction(
self.stream_ids = await self.id_gen._db.runInteraction_advanced(
"_load_next_mult_id",
True,
None,
self.id_gen._load_next_mult_id_txn,
self.multiple_ids or 1,
db_autocommit=True,
)
if self.multiple_ids is None:
@@ -818,10 +819,11 @@ class _MultiWriterCtxManager:
# for. If we don't do this then we'll often hit serialization errors due
# to the fact we default to REPEATABLE READ isolation levels.
if self.id_gen._writers:
await self.id_gen._db.runInteraction(
await self.id_gen._db.runInteraction_advanced(
"MultiWriterIdGenerator._update_table",
self.id_gen._update_stream_positions_table_txn,
db_autocommit=True,
isolation_level=None,
func=self.id_gen._update_stream_positions_table_txn,
)
return False