Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Get the cache storage passing mypy #11354

Closed
wants to merge 17 commits into from
Closed
1 change: 1 addition & 0 deletions changelog.d/11354.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to storage classes.
1 change: 0 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ exclude = (?x)
|synapse/storage/databases/__init__.py
|synapse/storage/databases/main/__init__.py
|synapse/storage/databases/main/account_data.py
|synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/e2e_room_keys.py
|synapse/storage/databases/main/end_to_end_keys.py
Expand Down
8 changes: 6 additions & 2 deletions scripts/synapse_port_db
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ from synapse.logging.context import (
make_deferred_yieldable,
run_in_background,
)
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.storage.database import DatabasePool, make_conn
from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore
Expand All @@ -57,7 +58,6 @@ from synapse.storage.databases.main.room import RoomBackgroundUpdateStore
from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore
from synapse.storage.databases.main.search import SearchBackgroundUpdateStore
from synapse.storage.databases.main.state import MainStateBackgroundUpdateStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.databases.main.user_directory import (
UserDirectoryBackgroundUpdateStore,
)
Expand Down Expand Up @@ -179,10 +179,10 @@ class Store(
MainStateBackgroundUpdateStore,
UserDirectoryBackgroundUpdateStore,
EndToEndKeyBackgroundStore,
StatsStore,
PusherWorkerStore,
PresenceBackgroundUpdateStore,
GroupServerWorkerStore,
SlavedEventStore,
):
def execute(self, f, *args, **kwargs):
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
Expand Down Expand Up @@ -229,6 +229,10 @@ class MockHomeserver:
def get_instance_name(self):
return "master"

def should_send_federation(self) -> bool:
"Should this server be sending federation traffic directly?"
return False


class Porter(object):
def __init__(self, **kwargs):
Expand Down
2 changes: 0 additions & 2 deletions synapse/app/admin_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from synapse.config.logger import setup_logging
from synapse.events import EventBase
from synapse.handlers.admin import ExfiltrationWriter
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
Expand Down Expand Up @@ -61,7 +60,6 @@ class AdminCmdSlavedStore(
SlavedPushRuleStore,
SlavedEventStore,
SlavedClientIpStore,
BaseSlavedStore,
RoomWorkerStore,
):
pass
Expand Down
6 changes: 0 additions & 6 deletions synapse/app/generic_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,12 @@
from synapse.logging.context import LoggingContext
from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.filtering import SlavedFilteringStore
from synapse.replication.slave.storage.groups import SlavedGroupServerStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
Expand Down Expand Up @@ -114,7 +112,6 @@
from synapse.storage.databases.main.room import RoomWorkerStore
from synapse.storage.databases.main.search import SearchStore
from synapse.storage.databases.main.session import SessionStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.databases.main.transactions import TransactionWorkerStore
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
from synapse.storage.databases.main.user_directory import UserDirectoryStore
Expand Down Expand Up @@ -223,7 +220,6 @@ class GenericWorkerSlavedStore(
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
# rather than going via the correct worker.
UserDirectoryStore,
StatsStore,
UIAuthWorkerStore,
EndToEndRoomKeyStore,
PresenceStore,
Expand All @@ -236,7 +232,6 @@ class GenericWorkerSlavedStore(
SlavedPusherStore,
CensorEventsStore,
ClientIpWorkerStore,
SlavedEventStore,
SlavedKeyStore,
RoomWorkerStore,
DirectoryStore,
Expand All @@ -252,7 +247,6 @@ class GenericWorkerSlavedStore(
TransactionWorkerStore,
LockStore,
SessionStore,
BaseSlavedStore,
):
# Properties that multiple storage classes define. Tell mypy what the
# expected type is.
Expand Down
6 changes: 2 additions & 4 deletions synapse/replication/slave/storage/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
Expand All @@ -30,9 +30,7 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen: Optional[
MultiWriterIdGenerator
] = MultiWriterIdGenerator(
self._cache_id_gen = MultiWriterIdGenerator(
db_conn,
database,
stream_name="caches",
Expand Down
16 changes: 0 additions & 16 deletions synapse/replication/slave/storage/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,7 @@
from typing import TYPE_CHECKING

from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.relations import RelationsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.databases.main.state import StateGroupWorkerStore
from synapse.storage.databases.main.stream import StreamWorkerStore
from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache

Expand All @@ -47,15 +38,8 @@


class SlavedEventStore(
EventFederationWorkerStore,
RoomMemberWorkerStore,
EventPushActionsWorkerStore,
StreamWorkerStore,
StateGroupWorkerStore,
EventsWorkerStore,
SignatureWorkerStore,
UserErasureWorkerStore,
RelationsWorkerStore,
BaseSlavedStore,
):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
Expand Down
2 changes: 1 addition & 1 deletion synapse/replication/slave/storage/push_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .events import SlavedEventStore


class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore):
def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token()

Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def commit(self) -> None:
def rollback(self) -> None:
self.conn.rollback()

def __enter__(self) -> "Connection":
def __enter__(self) -> "LoggingDatabaseConnection":
self.conn.__enter__()
return self

Expand Down
3 changes: 0 additions & 3 deletions synapse/storage/databases/main/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
from .session import SessionStore
from .signatures import SignatureStore
from .state import StateStore
from .stats import StatsStore
from .stream import StreamStore
from .tags import TagsStore
from .transactions import TransactionWorkerStore
Expand Down Expand Up @@ -119,7 +118,6 @@ class DataStore(
GroupServerStore,
UserErasureStore,
MonthlyActiveUsersStore,
StatsStore,
RelationsStore,
CensorEventsStore,
UIAuthStore,
Expand Down Expand Up @@ -154,7 +152,6 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
db_conn, "local_group_updates", "stream_id"
)

self._cache_id_gen: Optional[MultiWriterIdGenerator]
if isinstance(self.database_engine, PostgresEngine):
# We set the `writers` to an empty list here as we don't care about
# missing updates over restarts, as we'll not have anything in our
Expand Down
83 changes: 66 additions & 17 deletions synapse/storage/databases/main/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import itertools
import logging
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Tuple

from synapse.api.constants import EventTypes
from synapse.replication.tcp.streams import BackfillStream, CachesStream
Expand All @@ -24,9 +24,22 @@
EventsStreamCurrentStateRow,
EventsStreamEventRow,
)
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
from synapse.storage.databases.main.relations import RelationsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.databases.main.stream import StreamWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.util.caches.descriptors import _CachedFunction
from synapse.util.iterutils import batch_iter

if TYPE_CHECKING:
Expand All @@ -39,16 +52,35 @@
# based on the current state when notifying workers over replication.
CURRENT_STATE_CACHE_NAME = "cs_cache_fake"

# Corresponds to the (cache_func, keys, invalidation_ts) db columns.
_CacheData = Tuple[str, Optional[List[str]], Optional[int]]


class CacheInvalidationWorkerStore(
EventFederationWorkerStore,
RelationsWorkerStore,
EventPushActionsWorkerStore,
StreamWorkerStore,
StateDeltasStore,
RoomMemberWorkerStore,
):
# This class must be mixed in with a child class which provides the following
# attribute. TODO: can we get static analysis to enforce this?
_cache_id_gen: Optional[MultiWriterIdGenerator]

class CacheInvalidationWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

self._instance_name = hs.get_instance_name()

async def get_all_updated_caches(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
) -> Tuple[List[Tuple[int, _CacheData]], int, bool]:
"""Get updates for caches replication stream.

Args:
Expand All @@ -73,7 +105,9 @@ async def get_all_updated_caches(
if last_id == current_id:
return [], current_id, False

def get_all_updated_caches_txn(txn):
def get_all_updated_caches_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, _CacheData]], int, bool]:
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
Expand All @@ -85,7 +119,13 @@ def get_all_updated_caches_txn(txn):
LIMIT ?
"""
txn.execute(sql, (last_id, instance_name, limit))
updates = [(row[0], row[1:]) for row in txn]
updates: List[Tuple[int, _CacheData]] = []
row: Tuple[int, str, Optional[List[str]], Optional[int]]
# Type safety: iterating over `txn` yields `Tuple`, i.e.
# `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
# variadic tuple to a fixed length tuple and flags it up as an error.
for row in txn: # type: ignore[assignment]
updates.append((row[0], row[1:]))
clokep marked this conversation as resolved.
Show resolved Hide resolved
limited = False
upto_token = current_id
if len(updates) >= limit:
Expand Down Expand Up @@ -192,7 +232,9 @@ def _invalidate_caches_for_event(
self.get_aggregation_groups_for_event.invalidate((relates_to,))
self.get_applicable_edit.invalidate((relates_to,))

async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
async def invalidate_cache_and_stream(
self, cache_name: str, keys: Tuple[str, ...]
) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.

Expand All @@ -212,7 +254,9 @@ async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ..
keys,
)

def _invalidate_cache_and_stream(self, txn, cache_func, keys):
def _invalidate_cache_and_stream(
self, txn: LoggingTransaction, cache_func: _CachedFunction, keys: Iterable[str]
) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.

Expand All @@ -223,24 +267,28 @@ def _invalidate_cache_and_stream(self, txn, cache_func, keys):
txn.call_after(cache_func.invalidate, keys)
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)

def _invalidate_all_cache_and_stream(self, txn, cache_func):
def _invalidate_all_cache_and_stream(
self, txn: LoggingTransaction, cache_func: _CachedFunction
) -> None:
"""Invalidates the entire cache and adds it to the cache stream so slaves
will know to invalidate their caches.
"""

txn.call_after(cache_func.invalidate_all)
self._send_invalidation_to_replication(txn, cache_func.__name__, None)

def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
def _invalidate_state_caches_and_stream(
self, txn: LoggingTransaction, room_id: str, members_changed: Collection[str]
) -> None:
"""Special case invalidation of caches based on current state.

We special case this so that we can batch the cache invalidations into a
single replication poke.

Args:
txn
room_id (str): Room where state changed
members_changed (iterable[str]): The user_ids of members that have changed
room_id: Room where state changed
members_changed: The user_ids of members that have changed
"""
txn.call_after(self._invalidate_state_caches, room_id, members_changed)

Expand All @@ -262,8 +310,8 @@ def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
)

def _send_invalidation_to_replication(
self, txn, cache_name: str, keys: Optional[Iterable[Any]]
):
self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[str]]
) -> None:
"""Notifies replication that given cache has been invalidated.

Note that this does *not* invalidate the cache locally.
Expand All @@ -284,6 +332,7 @@ def _send_invalidation_to_replication(
# the transaction. However, we want to only get an ID when we want
# to use it, here, so we need to call __enter__ manually, and have
# __exit__ called after the transaction finishes.
assert self._cache_id_gen is not None
stream_id = self._cache_id_gen.get_next_txn(txn)
txn.call_after(self.hs.get_notifier().on_new_replication_data)

Expand All @@ -298,7 +347,7 @@ def _send_invalidation_to_replication(
"instance_name": self._instance_name,
"cache_func": cache_name,
"keys": keys,
"invalidation_ts": self.clock.time_msec(),
"invalidation_ts": self._clock.time_msec(),
},
)

Expand Down
Loading