Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Merge commit 'd4a7829b1' into anoa/dinsic_release_1_21_x
Browse files Browse the repository at this point in the history
* commit 'd4a7829b1':
  Convert synapse.api to async/await (#8031)
  • Loading branch information
anoadragon453 committed Oct 19, 2020
2 parents af13a4b + d4a7829 commit 383a87e
Show file tree
Hide file tree
Showing 22 changed files with 174 additions and 161 deletions.
1 change: 1 addition & 0 deletions changelog.d/8031.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
126 changes: 58 additions & 68 deletions synapse/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
from typing import List, Optional, Tuple

import pymacaroons
from netaddr import IPAddress

from twisted.internet import defer
from twisted.web.server import Request

import synapse.types
Expand Down Expand Up @@ -80,28 +79,28 @@ def __init__(self, hs):
self._track_appservice_user_ips = hs.config.track_appservice_user_ips
self._macaroon_secret_key = hs.config.macaroon_secret_key

@defer.inlineCallbacks
def check_from_context(self, room_version: str, event, context, do_sig_check=True):
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
auth_events_ids = yield self.compute_auth_events(
async def check_from_context(
self, room_version: str, event, context, do_sig_check=True
):
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = self.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}

room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
event_auth.check(
room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check
)

@defer.inlineCallbacks
def check_user_in_room(
async def check_user_in_room(
self,
room_id: str,
user_id: str,
current_state: Optional[StateMap[EventBase]] = None,
allow_departed_users: bool = False,
):
) -> EventBase:
"""Check if the user is in the room, or was at some point.
Args:
room_id: The room to check.
Expand All @@ -119,37 +118,35 @@ def check_user_in_room(
Raises:
AuthError if the user is/was not in the room.
Returns:
Deferred[Optional[EventBase]]:
Membership event for the user if the user was in the
room. This will be the join event if they are currently joined to
the room. This will be the leave event if they have left the room.
Membership event for the user if the user was in the
room. This will be the join event if they are currently joined to
the room. This will be the leave event if they have left the room.
"""
if current_state:
member = current_state.get((EventTypes.Member, user_id), None)
else:
member = yield defer.ensureDeferred(
self.state.get_current_state(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
member = await self.state.get_current_state(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
membership = member.membership if member else None

if membership == Membership.JOIN:
return member
if member:
membership = member.membership

# XXX this looks totally bogus. Why do we not allow users who have been banned,
# or those who were members previously and have been re-invited?
if allow_departed_users and membership == Membership.LEAVE:
forgot = yield self.store.did_forget(user_id, room_id)
if not forgot:
if membership == Membership.JOIN:
return member

# XXX this looks totally bogus. Why do we not allow users who have been banned,
# or those who were members previously and have been re-invited?
if allow_departed_users and membership == Membership.LEAVE:
forgot = await self.store.did_forget(user_id, room_id)
if not forgot:
return member

raise AuthError(403, "User %s not in room %s" % (user_id, room_id))

@defer.inlineCallbacks
def check_host_in_room(self, room_id, host):
async def check_host_in_room(self, room_id, host):
with Measure(self.clock, "check_host_in_room"):
latest_event_ids = yield self.store.is_host_joined(room_id, host)
latest_event_ids = await self.store.is_host_joined(room_id, host)
return latest_event_ids

def can_federate(self, event, auth_events):
Expand All @@ -160,14 +157,13 @@ def can_federate(self, event, auth_events):
def get_public_keys(self, invite_event):
return event_auth.get_public_keys(invite_event)

@defer.inlineCallbacks
def get_user_by_req(
async def get_user_by_req(
self,
request: Request,
allow_guest: bool = False,
rights: str = "access",
allow_expired: bool = False,
):
) -> synapse.types.Requester:
""" Get a registered user's ID.
Args:
Expand All @@ -180,7 +176,7 @@ def get_user_by_req(
/login will deliver access tokens regardless of expiration.
Returns:
defer.Deferred: resolves to a `synapse.types.Requester` object
Resolves to the requester
Raises:
InvalidClientCredentialsError if no user by that token exists or the token
is invalid.
Expand All @@ -194,15 +190,14 @@ def get_user_by_req(

access_token = self.get_access_token_from_request(request)

user_id, app_service = yield self._get_appservice_user_id(request)

user_id, app_service = self._get_appservice_user_id(request)
if user_id:
request.authenticated_entity = user_id
opentracing.set_tag("authenticated_entity", user_id)
opentracing.set_tag("appservice_id", app_service.id)

if ip_addr and self._track_appservice_user_ips:
yield self.store.insert_client_ip(
await self.store.insert_client_ip(
user_id=user_id,
access_token=access_token,
ip=ip_addr,
Expand All @@ -212,7 +207,7 @@ def get_user_by_req(

return synapse.types.create_requester(user_id, app_service=app_service)

user_info = yield self.get_user_by_access_token(
user_info = await self.get_user_by_access_token(
access_token, rights, allow_expired=allow_expired
)
user = user_info["user"]
Expand All @@ -222,10 +217,11 @@ def get_user_by_req(
# Deny the request if the user account has expired.
if self._account_validity.enabled and not allow_expired:
user_id = user.to_string()
expired = yield self.store.is_account_expired(
user_id, self.clock.time_msec()
)
if expired:
expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
if (
expiration_ts is not None
and self.clock.time_msec() >= expiration_ts
):
raise AuthError(
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
)
Expand All @@ -235,7 +231,7 @@ def get_user_by_req(
device_id = user_info.get("device_id")

if user and access_token and ip_addr:
yield self.store.insert_client_ip(
await self.store.insert_client_ip(
user_id=user.to_string(),
access_token=access_token,
ip=ip_addr,
Expand Down Expand Up @@ -291,10 +287,9 @@ def _get_appservice_user_id(self, request):
# )
return user_id, app_service

@defer.inlineCallbacks
def get_user_by_access_token(
async def get_user_by_access_token(
self, token: str, rights: str = "access", allow_expired: bool = False,
):
) -> dict:
""" Validate access token and get user_id from it
Args:
Expand All @@ -304,7 +299,7 @@ def get_user_by_access_token(
allow_expired: If False, raises an InvalidClientTokenError
if the token is expired
Returns:
Deferred[dict]: dict that includes:
dict that includes:
`user` (UserID)
`is_guest` (bool)
`token_id` (int|None): access token id. May be None if guest
Expand All @@ -318,7 +313,7 @@ def get_user_by_access_token(

if rights == "access":
# first look in the database
r = yield self._look_up_user_by_access_token(token)
r = await self._look_up_user_by_access_token(token)
if r:
valid_until_ms = r["valid_until_ms"]
if (
Expand Down Expand Up @@ -356,7 +351,7 @@ def get_user_by_access_token(
# It would of course be much easier to store guest access
# tokens in the database as well, but that would break existing
# guest tokens.
stored_user = yield self.store.get_user_by_id(user_id)
stored_user = await self.store.get_user_by_id(user_id)
if not stored_user:
raise InvalidClientTokenError("Unknown user_id %s" % user_id)
if not stored_user["is_guest"]:
Expand Down Expand Up @@ -486,9 +481,8 @@ def _verify_expiry(self, caveat):
now = self.hs.get_clock().time_msec()
return now < expiry

@defer.inlineCallbacks
def _look_up_user_by_access_token(self, token):
ret = yield self.store.get_user_by_access_token(token)
async def _look_up_user_by_access_token(self, token):
ret = await self.store.get_user_by_access_token(token)
if not ret:
return None

Expand All @@ -511,7 +505,7 @@ def get_appservice_by_req(self, request):
logger.warning("Unrecognised appservice access token.")
raise InvalidClientTokenError()
request.authenticated_entity = service.sender
return defer.succeed(service)
return service

async def is_server_admin(self, user: UserID) -> bool:
""" Check if the given user is a local server admin.
Expand All @@ -526,19 +520,19 @@ async def is_server_admin(self, user: UserID) -> bool:

def compute_auth_events(
self, event, current_state_ids: StateMap[str], for_verification: bool = False,
):
) -> List[str]:
"""Given an event and current state return the list of event IDs used
to auth an event.
If `for_verification` is False then only return auth events that
should be added to the event's `auth_events`.
Returns:
defer.Deferred(list[str]): List of event IDs.
List of event IDs.
"""

if event.type == EventTypes.Create:
return defer.succeed([])
return []

# Currently we ignore the `for_verification` flag even though there are
# some situations where we can drop particular auth events when adding
Expand All @@ -557,7 +551,7 @@ def compute_auth_events(
if auth_ev_id:
auth_ids.append(auth_ev_id)

return defer.succeed(auth_ids)
return auth_ids

async def check_can_change_room_list(self, room_id: str, user: UserID):
"""Determine whether the user is allowed to edit the room's entry in the
Expand Down Expand Up @@ -640,10 +634,9 @@ def get_access_token_from_request(request: Request):

return query_params[0].decode("ascii")

@defer.inlineCallbacks
def check_user_in_room_or_world_readable(
async def check_user_in_room_or_world_readable(
self, room_id: str, user_id: str, allow_departed_users: bool = False
):
) -> Tuple[str, Optional[str]]:
"""Checks that the user is or was in the room or the room is world
readable. If it isn't then an exception is raised.
Expand All @@ -654,10 +647,9 @@ def check_user_in_room_or_world_readable(
members but have now departed
Returns:
Deferred[tuple[str, str|None]]: Resolves to the current membership of
the user in the room and the membership event ID of the user. If
the user is not in the room and never has been, then
`(Membership.JOIN, None)` is returned.
Resolves to the current membership of the user in the room and the
membership event ID of the user. If the user is not in the room and
never has been, then `(Membership.JOIN, None)` is returned.
"""

try:
Expand All @@ -666,15 +658,13 @@ def check_user_in_room_or_world_readable(
# * The user is a non-guest user, and was ever in the room
# * The user is a guest user, and has joined the room
# else it will throw.
member_event = yield self.check_user_in_room(
member_event = await self.check_user_in_room(
room_id, user_id, allow_departed_users=allow_departed_users
)
return member_event.membership, member_event.event_id
except AuthError:
visibility = yield defer.ensureDeferred(
self.state.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
visibility = await self.state.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if (
visibility
Expand Down
13 changes: 5 additions & 8 deletions synapse/api/auth_blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

import logging

from twisted.internet import defer

from synapse.api.constants import LimitBlockingTypes, UserTypes
from synapse.api.errors import Codes, ResourceLimitError
from synapse.config.server import is_threepid_reserved
Expand All @@ -36,8 +34,7 @@ def __init__(self, hs):
self._limit_usage_by_mau = hs.config.limit_usage_by_mau
self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids

@defer.inlineCallbacks
def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
async def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag
Expand All @@ -60,7 +57,7 @@ def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
if user_id is not None:
if user_id == self._server_notices_mxid:
return
if (yield self.store.is_support_user(user_id)):
if await self.store.is_support_user(user_id):
return

if self._hs_disabled:
Expand All @@ -76,11 +73,11 @@ def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):

# If the user is already part of the MAU cohort or a trial user
if user_id:
timestamp = yield self.store.user_last_seen_monthly_active(user_id)
timestamp = await self.store.user_last_seen_monthly_active(user_id)
if timestamp:
return

is_trial = yield self.store.is_trial_user(user_id)
is_trial = await self.store.is_trial_user(user_id)
if is_trial:
return
elif threepid:
Expand All @@ -93,7 +90,7 @@ def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
# allow registration. Support users are excluded from MAU checks.
return
# Else if there is no room in the MAU bucket, bail
current_mau = yield self.store.get_monthly_active_count()
current_mau = await self.store.get_monthly_active_count()
if current_mau >= self._max_mau_value:
raise ResourceLimitError(
403,
Expand Down
7 changes: 2 additions & 5 deletions synapse/api/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
from canonicaljson import json
from jsonschema import FormatChecker

from twisted.internet import defer

from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError
from synapse.storage.presence import UserPresenceState
Expand Down Expand Up @@ -137,9 +135,8 @@ def __init__(self, hs):
super(Filtering, self).__init__()
self.store = hs.get_datastore()

@defer.inlineCallbacks
def get_user_filter(self, user_localpart, filter_id):
result = yield self.store.get_user_filter(user_localpart, filter_id)
async def get_user_filter(self, user_localpart, filter_id):
result = await self.store.get_user_filter(user_localpart, filter_id)
return FilterCollection(result)

def add_user_filter(self, user_localpart, user_filter):
Expand Down
Loading

0 comments on commit 383a87e

Please sign in to comment.