Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Merge commit '9b7ac03af' into anoa/dinsic_release_1_21_x
Browse files Browse the repository at this point in the history
* commit '9b7ac03af':
  Convert calls of async database methods to async (#8166)
  simple_search_list_txn should return None, not 0. (#8187)
  Fix missing _add_persisted_position (#8179)
  • Loading branch information
anoadragon453 committed Oct 20, 2020
2 parents 118f41d + 9b7ac03 commit ec50c99
Show file tree
Hide file tree
Showing 19 changed files with 171 additions and 93 deletions.
1 change: 1 addition & 0 deletions changelog.d/8166.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
1 change: 1 addition & 0 deletions changelog.d/8179.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add functions to `MultiWriterIdGen` used by events stream.
1 change: 1 addition & 0 deletions changelog.d/8187.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to `synapse.storage.database`.
16 changes: 9 additions & 7 deletions synapse/federation/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@

import logging

from synapse.federation.units import Transaction
from synapse.logging.utils import log_function
from synapse.types import JsonDict

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -49,15 +51,15 @@ def have_responded(self, origin, transaction):
return self.store.get_received_txn_response(transaction.transaction_id, origin)

@log_function
def set_response(self, origin, transaction, code, response):
async def set_response(
self, origin: str, transaction: Transaction, code: int, response: JsonDict
) -> None:
""" Persist how we responded to a transaction.
Returns:
Deferred
"""
if not transaction.transaction_id:
transaction_id = transaction.transaction_id # type: ignore
if not transaction_id:
raise RuntimeError("Cannot persist a transaction with no transaction_id")

return self.store.set_received_txn_response(
transaction.transaction_id, origin, code, response
await self.store.set_received_txn_response(
transaction_id, origin, code, response
)
4 changes: 1 addition & 3 deletions synapse/federation/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,7 @@ def __init__(self, transaction_id=None, pdus=[], **kwargs):
if "edus" in kwargs and not kwargs["edus"]:
del kwargs["edus"]

super(Transaction, self).__init__(
transaction_id=transaction_id, pdus=pdus, **kwargs
)
super().__init__(transaction_id=transaction_id, pdus=pdus, **kwargs)

@staticmethod
def create_new(pdus, **kwargs):
Expand Down
7 changes: 3 additions & 4 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
Optional,
Tuple,
TypeVar,
Union,
overload,
)

Expand Down Expand Up @@ -1655,7 +1654,7 @@ def simple_search_list_txn(
term: Optional[str],
col: str,
retcols: Iterable[str],
) -> Union[List[Dict[str, Any]], int]:
) -> Optional[List[Dict[str, Any]]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Expand All @@ -1667,14 +1666,14 @@ def simple_search_list_txn(
retcols: the names of the columns to return
Returns:
0 if no term is given, otherwise a list of dictionaries.
None if no term is given, otherwise a list of dictionaries.
"""
if term:
sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
termvalues = ["%%" + term + "%%"]
txn.execute(sql, termvalues)
else:
return 0
return None

return cls.cursor_to_dict(txn)

Expand Down
6 changes: 2 additions & 4 deletions synapse/storage/databases/main/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,16 +161,14 @@ async def get_appservice_state(self, service):
return result.get("state")
return None

def set_appservice_state(self, service, state):
async def set_appservice_state(self, service, state) -> None:
"""Set the application service state.
Args:
service(ApplicationService): The service whose state to set.
state(ApplicationServiceState): The connectivity state to apply.
Returns:
An Awaitable which resolves when the state was set successfully.
"""
return self.db_pool.simple_upsert(
await self.db_pool.simple_upsert(
"application_services_state", {"as_id": service.id}, {"state": state}
)

Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,11 +716,11 @@ async def get_user_ids_requiring_device_list_resync(

return {row["user_id"] for row in rows}

def mark_remote_user_device_cache_as_stale(self, user_id: str):
async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None:
"""Records that the server has reason to believe the cache of the devices
for the remote users is out of date.
"""
return self.db_pool.simple_upsert(
await self.db_pool.simple_upsert(
table="device_lists_remote_resync",
keyvalues={"user_id": user_id},
values={},
Expand Down
30 changes: 22 additions & 8 deletions synapse/storage/databases/main/group_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,13 @@ def remove_room_from_summary(self, group_id, room_id, category_id):
desc="remove_room_from_summary",
)

def upsert_group_category(self, group_id, category_id, profile, is_public):
async def upsert_group_category(
self,
group_id: str,
category_id: str,
profile: Optional[JsonDict],
is_public: Optional[bool],
) -> None:
"""Add/update room category for group
"""
insertion_values = {}
Expand All @@ -758,7 +764,7 @@ def upsert_group_category(self, group_id, category_id, profile, is_public):
else:
update_values["is_public"] = is_public

return self.db_pool.simple_upsert(
await self.db_pool.simple_upsert(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
values=update_values,
Expand All @@ -773,7 +779,13 @@ def remove_group_category(self, group_id, category_id):
desc="remove_group_category",
)

def upsert_group_role(self, group_id, role_id, profile, is_public):
async def upsert_group_role(
self,
group_id: str,
role_id: str,
profile: Optional[JsonDict],
is_public: Optional[bool],
) -> None:
"""Add/update user role for group
"""
insertion_values = {}
Expand All @@ -789,7 +801,7 @@ def upsert_group_role(self, group_id, role_id, profile, is_public):
else:
update_values["is_public"] = is_public

return self.db_pool.simple_upsert(
await self.db_pool.simple_upsert(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
values=update_values,
Expand Down Expand Up @@ -938,10 +950,10 @@ def remove_user_from_summary(self, group_id, user_id, role_id):
desc="remove_user_from_summary",
)

def add_group_invite(self, group_id, user_id):
async def add_group_invite(self, group_id: str, user_id: str) -> None:
"""Record that the group server has invited a user
"""
return self.db_pool.simple_insert(
await self.db_pool.simple_insert(
table="group_invites",
values={"group_id": group_id, "user_id": user_id},
desc="add_group_invite",
Expand Down Expand Up @@ -1044,8 +1056,10 @@ def _remove_user_from_group_txn(txn):
"remove_user_from_group", _remove_user_from_group_txn
)

def add_room_to_group(self, group_id, room_id, is_public):
return self.db_pool.simple_insert(
async def add_room_to_group(
self, group_id: str, room_id: str, is_public: bool
) -> None:
await self.db_pool.simple_insert(
table="group_rooms",
values={"group_id": group_id, "room_id": room_id, "is_public": is_public},
desc="add_room_to_group",
Expand Down
26 changes: 16 additions & 10 deletions synapse/storage/databases/main/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,22 +140,28 @@ async def store_server_verify_keys(
for i in invalidations:
invalidate((i,))

def store_server_keys_json(
self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
):
async def store_server_keys_json(
self,
server_name: str,
key_id: str,
from_server: str,
ts_now_ms: int,
ts_expires_ms: int,
key_json_bytes: bytes,
) -> None:
"""Stores the JSON bytes for a set of keys from a server
The JSON should be signed by the originating server, the intermediate
server, and by this server. Updates the value for the
(server_name, key_id, from_server) triplet if one already existed.
Args:
server_name (str): The name of the server.
key_id (str): The identifier of the key this JSON is for.
from_server (str): The server this JSON was fetched from.
ts_now_ms (int): The time now in milliseconds.
ts_valid_until_ms (int): The time when this json stops being valid.
key_json (bytes): The encoded JSON.
server_name: The name of the server.
key_id: The identifier of the key this JSON is for.
from_server: The server this JSON was fetched from.
ts_now_ms: The time now in milliseconds.
ts_valid_until_ms: The time when this json stops being valid.
key_json_bytes: The encoded JSON.
"""
return self.db_pool.simple_upsert(
await self.db_pool.simple_upsert(
table="server_keys_json",
keyvalues={
"server_name": server_name,
Expand Down
22 changes: 11 additions & 11 deletions synapse/storage/databases/main/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
desc="get_local_media",
)

def store_local_media(
async def store_local_media(
self,
media_id,
media_type,
Expand All @@ -69,8 +69,8 @@ def store_local_media(
media_length,
user_id,
url_cache=None,
):
return self.db_pool.simple_insert(
) -> None:
await self.db_pool.simple_insert(
"local_media_repository",
{
"media_id": media_id,
Expand Down Expand Up @@ -141,10 +141,10 @@ def get_url_cache_txn(txn):

return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)

def store_url_cache(
async def store_url_cache(
self, url, response_code, etag, expires_ts, og, media_id, download_ts
):
return self.db_pool.simple_insert(
await self.db_pool.simple_insert(
"local_media_repository_url_cache",
{
"url": url,
Expand Down Expand Up @@ -172,7 +172,7 @@ async def get_local_media_thumbnails(self, media_id: str) -> List[Dict[str, Any]
desc="get_local_media_thumbnails",
)

def store_local_thumbnail(
async def store_local_thumbnail(
self,
media_id,
thumbnail_width,
Expand All @@ -181,7 +181,7 @@ def store_local_thumbnail(
thumbnail_method,
thumbnail_length,
):
return self.db_pool.simple_insert(
await self.db_pool.simple_insert(
"local_media_repository_thumbnails",
{
"media_id": media_id,
Expand Down Expand Up @@ -212,7 +212,7 @@ async def get_cached_remote_media(
desc="get_cached_remote_media",
)

def store_cached_remote_media(
async def store_cached_remote_media(
self,
origin,
media_id,
Expand All @@ -222,7 +222,7 @@ def store_cached_remote_media(
upload_name,
filesystem_id,
):
return self.db_pool.simple_insert(
await self.db_pool.simple_insert(
"remote_media_cache",
{
"media_origin": origin,
Expand Down Expand Up @@ -288,7 +288,7 @@ async def get_remote_media_thumbnails(
desc="get_remote_media_thumbnails",
)

def store_remote_media_thumbnail(
async def store_remote_media_thumbnail(
self,
origin,
media_id,
Expand All @@ -299,7 +299,7 @@ def store_remote_media_thumbnail(
thumbnail_method,
thumbnail_length,
):
return self.db_pool.simple_insert(
await self.db_pool.simple_insert(
"remote_media_cache_thumbnails",
{
"media_origin": origin,
Expand Down
6 changes: 4 additions & 2 deletions synapse/storage/databases/main/openid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@


class OpenIdStore(SQLBaseStore):
def insert_open_id_token(self, token, ts_valid_until_ms, user_id):
return self.db_pool.simple_insert(
async def insert_open_id_token(
self, token: str, ts_valid_until_ms: int, user_id: str
) -> None:
await self.db_pool.simple_insert(
table="open_id_tokens",
values={
"token": token,
Expand Down
13 changes: 7 additions & 6 deletions synapse/storage/databases/main/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ async def get_from_remote_profile_cache(
desc="get_from_remote_profile_cache",
)

def create_profile(self, user_localpart):
return self.db_pool.simple_insert(
async def create_profile(self, user_localpart: str) -> None:
await self.db_pool.simple_insert(
table="profiles", values={"user_id": user_localpart}, desc="create_profile"
)

Expand Down Expand Up @@ -197,8 +197,7 @@ async def set_profiles_active(

class ProfileStore(ProfileWorkerStore):
def __init__(self, database, db_conn, hs):

super(ProfileStore, self).__init__(database, db_conn, hs)
super().__init__(database, db_conn, hs)

self.db_pool.updates.register_background_index_update(
"profile_replication_status_host_index",
Expand All @@ -208,13 +207,15 @@ def __init__(self, database, db_conn, hs):
unique=True,
)

def add_remote_profile_cache(self, user_id, displayname, avatar_url):
async def add_remote_profile_cache(
self, user_id: str, displayname: str, avatar_url: str
) -> None:
"""Ensure we are caching the remote user's profiles.
This should only be called when `is_subscribed_remote_profile_for_user`
would return true for the user.
"""
return self.db_pool.simple_upsert(
await self.db_pool.simple_upsert(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
values={
Expand Down
Loading

0 comments on commit ec50c99

Please sign in to comment.