Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to synapse.events.*. #11066

Merged
merged 11 commits into from
Oct 13, 2021
1 change: 1 addition & 0 deletions changelog.d/11066.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to `synapse.events`.
6 changes: 6 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@ files =
synapse/crypto,
synapse/event_auth.py,
synapse/events/builder.py,
synapse/events/presence_router.py,
synapse/events/snapshot.py,
synapse/events/spamcheck.py,
synapse/events/third_party_rules.py,
synapse/events/utils.py,
synapse/events/validator.py,
synapse/federation,
synapse/groups,
Expand Down Expand Up @@ -95,6 +98,9 @@ files =
tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py

[mypy-synapse.events.*]
disallow_untyped_defs = True

[mypy-synapse.handlers.*]
disallow_untyped_defs = True

Expand Down
4 changes: 2 additions & 2 deletions synapse/events/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,13 @@ class EventBuilder:
)

@property
def state_key(self):
def state_key(self) -> str:
if self._state_key is not None:
return self._state_key

raise AttributeError("state_key")

def is_state(self):
def is_state(self) -> bool:
return self._state_key is not None

async def build(
Expand Down
16 changes: 8 additions & 8 deletions synapse/events/presence_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import logging
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Expand All @@ -33,14 +34,13 @@
GET_USERS_FOR_STATES_CALLBACK = Callable[
[Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]]
]
GET_INTERESTED_USERS_CALLBACK = Callable[
[str], Awaitable[Union[Set[str], "PresenceRouter.ALL_USERS"]]
]
# This must either return a set of strings or the constant PresenceRouter.ALL_USERS.
GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[Set[str], str]]]

logger = logging.getLogger(__name__)


def load_legacy_presence_router(hs: "HomeServer"):
def load_legacy_presence_router(hs: "HomeServer") -> None:
"""Wrapper that loads a presence router module configured using the old
configuration, and registers the hooks they implement.
"""
Expand Down Expand Up @@ -69,7 +69,7 @@ def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
if f is None:
return None

def run(*args, **kwargs):
def run(*args: Any, **kwargs: Any) -> Awaitable:
# mypy doesn't do well across function boundaries so we need to tell it
clokep marked this conversation as resolved.
Show resolved Hide resolved
# f is definitely not None.
assert f is not None
Expand Down Expand Up @@ -104,7 +104,7 @@ def register_presence_router_callbacks(
self,
get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None,
get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None,
):
) -> None:
# PresenceRouter modules are required to implement both of these methods
# or neither of them as they are assumed to act in a complementary manner
paired_methods = [get_users_for_states, get_interested_users]
Expand Down Expand Up @@ -142,7 +142,7 @@ async def get_users_for_states(
# Don't include any extra destinations for presence updates
return {}

users_for_states = {}
users_for_states: Dict[str, Set[UserPresenceState]] = {}
# run all the callbacks for get_users_for_states and combine the results
for callback in self._get_users_for_states_callbacks:
try:
Expand Down Expand Up @@ -171,7 +171,7 @@ async def get_users_for_states(

return users_for_states

async def get_interested_users(self, user_id: str) -> Union[Set[str], ALL_USERS]:
async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
"""
Retrieve a list of users that `user_id` is interested in receiving the
presence of. This will be in addition to those they share a room with.
Expand Down
108 changes: 58 additions & 50 deletions synapse/events/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import attr
from frozendict import frozendict

from twisted.internet.defer import Deferred

from synapse.appservice import ApplicationService
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import StateMap
from synapse.types import JsonDict, StateMap

if TYPE_CHECKING:
from synapse.storage import Storage
from synapse.storage.databases.main import DataStore


Expand Down Expand Up @@ -112,13 +115,13 @@ class EventContext:

@staticmethod
def with_state(
state_group,
state_group_before_event,
current_state_ids,
prev_state_ids,
prev_group=None,
delta_ids=None,
):
state_group: Optional[int],
state_group_before_event: Optional[int],
current_state_ids: Optional[StateMap[str]],
prev_state_ids: Optional[StateMap[str]],
prev_group: Optional[int] = None,
delta_ids: Optional[StateMap[str]] = None,
) -> "EventContext":
return EventContext(
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
Expand All @@ -129,22 +132,22 @@ def with_state(
)

@staticmethod
def for_outlier():
def for_outlier() -> "EventContext":
"""Return an EventContext instance suitable for persisting an outlier event"""
return EventContext(
current_state_ids={},
prev_state_ids={},
)

async def serialize(self, event: EventBase, store: "DataStore") -> dict:
async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
"""Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize`

Args:
event (FrozenEvent): The event that this context relates to
event: The event that this context relates to

Returns:
dict
The serialized event.
"""

# We don't serialize the full state dicts, instead they get pulled out
Expand All @@ -170,17 +173,16 @@ async def serialize(self, event: EventBase, store: "DataStore") -> dict:
}

@staticmethod
def deserialize(storage, input):
def deserialize(storage: "Storage", input: JsonDict) -> "EventContext":
"""Converts a dict that was produced by `serialize` back into a
EventContext.

Args:
storage (Storage): Used to convert AS ID to AS object and fetch
state.
input (dict): A dict produced by `serialize`
storage: Used to convert AS ID to AS object and fetch state.
input: A dict produced by `serialize`

Returns:
EventContext
The event context.
"""
context = _AsyncEventContextImpl(
# We use the state_group and prev_state_id stuff to pull the
Expand Down Expand Up @@ -241,39 +243,41 @@ async def get_current_state_ids(self) -> Optional[StateMap[str]]:
await self._ensure_fetched()
return self._current_state_ids

async def get_prev_state_ids(self):
async def get_prev_state_ids(self) -> StateMap[str]:
"""
Gets the room state map, excluding this event.

For a non-state event, this will be the same as get_current_state_ids().

Returns:
dict[(str, str), str]|None: Returns None if state_group
is None, which happens when the associated event is an outlier.
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
Returns {} if state_group is None, which happens when the associated
event is an outlier.

Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
await self._ensure_fetched()
# There *should* be previous state IDs now.
assert self._prev_state_ids is not None
return self._prev_state_ids

def get_cached_current_state_ids(self):
def get_cached_current_state_ids(self) -> Optional[StateMap[str]]:
"""Gets the current state IDs if we have them already cached.

It is an error to access this for a rejected event, since rejected state should
not make it into the room state. This method will raise an exception if
``rejected`` is set.

Returns:
dict[(str, str), str]|None: Returns None if we haven't cached the
state or if state_group is None, which happens when the associated
event is an outlier.
Returns None if we haven't cached the state or if state_group is None,
which happens when the associated event is an outlier.
"""
if self.rejected:
raise RuntimeError("Attempt to access state_ids of rejected event")

return self._current_state_ids

async def _ensure_fetched(self):
async def _ensure_fetched(self) -> None:
return None


Expand All @@ -285,57 +289,59 @@ class _AsyncEventContextImpl(EventContext):

Attributes:

_storage (Storage)
_storage

_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
been calculated. None if we haven't started calculating yet
_fetching_state_deferred: Resolves when *_state_ids have been calculated.
None if we haven't started calculating yet

_event_type (str): The type of the event the context is associated with.
_event_type: The type of the event the context is associated with.

_event_state_key (str): The state_key of the event the context is
associated with.
_event_state_key: The state_key of the event the context is associated with.

_prev_state_id (str|None): If the event associated with the context is
a state event, then `_prev_state_id` is the event_id of the state
that was replaced.
_prev_state_id: If the event associated with the context is a state event,
then `_prev_state_id` is the event_id of the state that was replaced.
"""

# This needs to have a default as we're inheriting
_storage = attr.ib(default=None)
_prev_state_id = attr.ib(default=None)
_event_type = attr.ib(default=None)
_event_state_key = attr.ib(default=None)
_fetching_state_deferred = attr.ib(default=None)
_storage: "Storage" = attr.ib(default=None)
_prev_state_id: Optional[str] = attr.ib(default=None)
_event_type: str = attr.ib(default=None)
_event_state_key: Optional[str] = attr.ib(default=None)
_fetching_state_deferred: Optional["Deferred[None]"] = attr.ib(default=None)

async def _ensure_fetched(self):
async def _ensure_fetched(self) -> None:
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(self._fill_out_state)

return await make_deferred_yieldable(self._fetching_state_deferred)
await make_deferred_yieldable(self._fetching_state_deferred)

async def _fill_out_state(self):
async def _fill_out_state(self) -> None:
"""Called to populate the _current_state_ids and _prev_state_ids
attributes by loading from the database.
"""
if self.state_group is None:
return

self._current_state_ids = await self._storage.state.get_state_ids_for_group(
current_state_ids = await self._storage.state.get_state_ids_for_group(
self.state_group
)
# Set this separately so mypy knows current_state_ids is not None.
self._current_state_ids = current_state_ids
if self._event_state_key is not None:
self._prev_state_ids = dict(self._current_state_ids)
self._prev_state_ids = dict(current_state_ids)

key = (self._event_type, self._event_state_key)
if self._prev_state_id:
self._prev_state_ids[key] = self._prev_state_id
else:
self._prev_state_ids.pop(key, None)
else:
self._prev_state_ids = self._current_state_ids
self._prev_state_ids = current_state_ids


def _encode_state_dict(state_dict):
def _encode_state_dict(
state_dict: Optional[StateMap[str]],
) -> Optional[List[Tuple[str, str, str]]]:
"""Since dicts of (type, state_key) -> event_id cannot be serialized in
JSON we need to convert them to a form that can.
"""
Expand All @@ -345,7 +351,9 @@ def _encode_state_dict(state_dict):
return [(etype, state_key, v) for (etype, state_key), v in state_dict.items()]


def _decode_state_dict(input):
def _decode_state_dict(
input: Optional[List[Tuple[str, str, str]]]
) -> Optional[StateMap[str]]:
"""Decodes a state dict encoded using `_encode_state_dict` above"""
if input is None:
return None
Expand Down
14 changes: 8 additions & 6 deletions synapse/events/spamcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
]


def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
"""Wrapper that loads spam checkers configured using the old configuration, and
registers the spam checker hooks they implement.
"""
Expand Down Expand Up @@ -146,7 +146,7 @@ def wrapper(
"Bad signature for callback check_registration_for_spam",
)

def run(*args, **kwargs):
def run(*args: Any, **kwargs: Any) -> Awaitable:
# mypy doesn't do well across function boundaries so we need to tell it
# wrapped_func is definitely not None.
assert wrapped_func is not None
Expand All @@ -165,7 +165,7 @@ def run(*args, **kwargs):


class SpamChecker:
def __init__(self):
def __init__(self) -> None:
self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = []
self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = []
self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = []
Expand Down Expand Up @@ -209,7 +209,7 @@ def register_callbacks(
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
] = None,
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
):
) -> None:
"""Register callbacks from module for each hook."""
if check_event_for_spam is not None:
self._check_event_for_spam_callbacks.append(check_event_for_spam)
Expand Down Expand Up @@ -275,7 +275,9 @@ async def check_event_for_spam(

return False

async def user_may_join_room(self, user_id: str, room_id: str, is_invited: bool):
async def user_may_join_room(
self, user_id: str, room_id: str, is_invited: bool
) -> bool:
"""Checks if a given users is allowed to join a room.
Not called when a user creates a room.

Expand All @@ -285,7 +287,7 @@ async def user_may_join_room(self, user_id: str, room_id: str, is_invited: bool)
is_invited: Whether the user is invited into the room

Returns:
bool: Whether the user may join the room
Whether the user may join the room
"""
for callback in self._user_may_join_room_callbacks:
if await callback(user_id, room_id, is_invited) is False:
Expand Down
Loading