Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert push to async/await. #7948

Merged
merged 7 commits into from
Jul 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/7948.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert push to async/await.
7 changes: 2 additions & 5 deletions synapse/push/action_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

import logging

from twisted.internet import defer

from synapse.util.metrics import Measure

from .bulk_push_rule_evaluator import BulkPushRuleEvaluator
Expand All @@ -37,7 +35,6 @@ def __init__(self, hs):
# event stream, so we just run the rules for a client with no profile
# tag (ie. we just need all the users).

@defer.inlineCallbacks
def handle_push_actions_for_event(self, event, context):
async def handle_push_actions_for_event(self, event, context):
with Measure(self.clock, "action_for_event_by_user"):
yield self.bulk_evaluator.action_for_event_by_user(event, context)
await self.bulk_evaluator.action_for_event_by_user(event, context)
62 changes: 25 additions & 37 deletions synapse/push/bulk_push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@

from prometheus_client import Counter

from twisted.internet import defer

from synapse.api.constants import EventTypes, Membership
from synapse.event_auth import get_user_power_level
from synapse.state import POWER_KEY
Expand Down Expand Up @@ -70,28 +68,27 @@ def __init__(self, hs):
resizable=False,
)

@defer.inlineCallbacks
def _get_rules_for_event(self, event, context):
async def _get_rules_for_event(self, event, context):
"""This gets the rules for all users in the room at the time of the event,
as well as the push rules for the invitee if the event is an invite.

Returns:
dict of user_id -> push_rules
"""
room_id = event.room_id
rules_for_room = yield self._get_rules_for_room(room_id)
rules_for_room = await self._get_rules_for_room(room_id)

rules_by_user = yield rules_for_room.get_rules(event, context)
rules_by_user = await rules_for_room.get_rules(event, context)

# if this event is an invite event, we may need to run rules for the user
# who's been invited, otherwise they won't get told they've been invited
if event.type == "m.room.member" and event.content["membership"] == "invite":
invited = event.state_key
if invited and self.hs.is_mine_id(invited):
has_pusher = yield self.store.user_has_pusher(invited)
has_pusher = await self.store.user_has_pusher(invited)
if has_pusher:
rules_by_user = dict(rules_by_user)
rules_by_user[invited] = yield self.store.get_push_rules_for_user(
rules_by_user[invited] = await self.store.get_push_rules_for_user(
invited
)

Expand All @@ -114,20 +111,19 @@ def _get_rules_for_room(self, room_id):
self.room_push_rule_cache_metrics,
)

@defer.inlineCallbacks
def _get_power_levels_and_sender_level(self, event, context):
prev_state_ids = yield context.get_prev_state_ids()
async def _get_power_levels_and_sender_level(self, event, context):
prev_state_ids = await context.get_prev_state_ids()
pl_event_id = prev_state_ids.get(POWER_KEY)
if pl_event_id:
# fastpath: if there's a power level event, that's all we need, and
# not having a power level event is an extreme edge case
pl_event = yield self.store.get_event(pl_event_id)
pl_event = await self.store.get_event(pl_event_id)
auth_events = {POWER_KEY: pl_event}
else:
auth_events_ids = yield self.auth.compute_auth_events(
auth_events_ids = await self.auth.compute_auth_events(
event, prev_state_ids, for_verification=False
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}

sender_level = get_user_power_level(event.sender, auth_events)
Expand All @@ -136,23 +132,19 @@ def _get_power_levels_and_sender_level(self, event, context):

return pl_event.content if pl_event else {}, sender_level

@defer.inlineCallbacks
def action_for_event_by_user(self, event, context):
async def action_for_event_by_user(self, event, context) -> None:
"""Given an event and context, evaluate the push rules and insert the
results into the event_push_actions_staging table.

Returns:
Deferred
"""
rules_by_user = yield self._get_rules_for_event(event, context)
rules_by_user = await self._get_rules_for_event(event, context)
actions_by_user = {}

room_members = yield self.store.get_joined_users_from_context(event, context)
room_members = await self.store.get_joined_users_from_context(event, context)

(
power_levels,
sender_power_level,
) = yield self._get_power_levels_and_sender_level(event, context)
) = await self._get_power_levels_and_sender_level(event, context)

evaluator = PushRuleEvaluatorForEvent(
event, len(room_members), sender_power_level, power_levels
Expand All @@ -165,7 +157,7 @@ def action_for_event_by_user(self, event, context):
continue

if not event.is_state():
is_ignored = yield self.store.is_ignored_by(event.sender, uid)
is_ignored = await self.store.is_ignored_by(event.sender, uid)
if is_ignored:
continue

Expand Down Expand Up @@ -197,7 +189,7 @@ def action_for_event_by_user(self, event, context):
# Mark in the DB staging area the push actions for users who should be
# notified for this event. (This will then get handled when we persist
# the event)
yield self.store.add_push_actions_to_staging(event.event_id, actions_by_user)
await self.store.add_push_actions_to_staging(event.event_id, actions_by_user)


def _condition_checker(evaluator, conditions, uid, display_name, cache):
Expand Down Expand Up @@ -274,8 +266,7 @@ def __init__(self, hs, room_id, rules_for_room_cache, room_push_rule_cache_metri
# to self around in the callback.
self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)

@defer.inlineCallbacks
def get_rules(self, event, context):
async def get_rules(self, event, context):
"""Given an event context return the rules for all users who are
currently in the room.
"""
Expand All @@ -286,7 +277,7 @@ def get_rules(self, event, context):
self.room_push_rule_cache_metrics.inc_hits()
return self.rules_by_user

with (yield self.linearizer.queue(())):
with (await self.linearizer.queue(())):
if state_group and self.state_group == state_group:
logger.debug("Using cached rules for %r", self.room_id)
self.room_push_rule_cache_metrics.inc_hits()
Expand All @@ -304,9 +295,7 @@ def get_rules(self, event, context):

push_rules_delta_state_cache_metric.inc_hits()
else:
current_state_ids = yield defer.ensureDeferred(
context.get_current_state_ids()
)
current_state_ids = await context.get_current_state_ids()
push_rules_delta_state_cache_metric.inc_misses()

push_rules_state_size_counter.inc(len(current_state_ids))
Expand Down Expand Up @@ -353,7 +342,7 @@ def get_rules(self, event, context):
# If we have some memebr events we haven't seen, look them up
# and fetch push rules for them if appropriate.
logger.debug("Found new member events %r", missing_member_event_ids)
yield self._update_rules_with_member_event_ids(
await self._update_rules_with_member_event_ids(
ret_rules_by_user, missing_member_event_ids, state_group, event
)
else:
Expand All @@ -371,8 +360,7 @@ def get_rules(self, event, context):
)
return ret_rules_by_user

@defer.inlineCallbacks
def _update_rules_with_member_event_ids(
async def _update_rules_with_member_event_ids(
self, ret_rules_by_user, member_event_ids, state_group, event
):
"""Update the partially filled rules_by_user dict by fetching rules for
Expand All @@ -388,7 +376,7 @@ def _update_rules_with_member_event_ids(
"""
sequence = self.sequence

rows = yield self.store.get_membership_from_event_ids(member_event_ids.values())
rows = await self.store.get_membership_from_event_ids(member_event_ids.values())

members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows}

Expand All @@ -410,7 +398,7 @@ def _update_rules_with_member_event_ids(

logger.debug("Joined: %r", interested_in_user_ids)

if_users_with_pushers = yield self.store.get_if_users_have_pushers(
if_users_with_pushers = await self.store.get_if_users_have_pushers(
interested_in_user_ids, on_invalidate=self.invalidate_all_cb
)

Expand All @@ -420,7 +408,7 @@ def _update_rules_with_member_event_ids(

logger.debug("With pushers: %r", user_ids)

users_with_receipts = yield self.store.get_users_with_read_receipts_in_room(
users_with_receipts = await self.store.get_users_with_read_receipts_in_room(
self.room_id, on_invalidate=self.invalidate_all_cb
)

Expand All @@ -431,7 +419,7 @@ def _update_rules_with_member_event_ids(
if uid in interested_in_user_ids:
user_ids.add(uid)

rules_by_user = yield self.store.bulk_get_push_rules(
rules_by_user = await self.store.bulk_get_push_rules(
user_ids, on_invalidate=self.invalidate_all_cb
)

Expand Down
Loading