Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to synapse/tests/rest/admin #11590

Merged
merged 2 commits into from
Dec 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11590.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to `synapse/tests/rest/admin`.
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,9 @@ disallow_untyped_defs = True
[mypy-tests.storage.test_user_directory]
disallow_untyped_defs = True

[mypy-tests.rest.admin.*]
disallow_untyped_defs = True

[mypy-tests.rest.client.test_directory]
disallow_untyped_defs = True

Expand Down
3 changes: 2 additions & 1 deletion tests/rest/admin/test_background_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.storage.background_updates import BackgroundUpdater
from synapse.types import JsonDict
from synapse.util import Clock

from tests import unittest
Expand Down Expand Up @@ -96,7 +97,7 @@ def test_invalid_parameter(self) -> None:
def _register_bg_update(self) -> None:
"Adds a bg update but doesn't start it"

async def _fake_update(progress, batch_size) -> int:
async def _fake_update(progress: JsonDict, batch_size: int) -> int:
await self.clock.sleep(0.2)
return batch_size

Expand Down
33 changes: 18 additions & 15 deletions tests/rest/admin/test_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@

from parameterized import parameterized

from twisted.test.proto_helpers import MemoryReactor

import synapse.rest.admin
from synapse.api.errors import Codes
from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock

from tests import unittest

Expand All @@ -31,7 +34,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]

def prepare(self, reactor, clock, hs: HomeServer):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastore()
self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
Expand All @@ -44,7 +47,7 @@ def prepare(self, reactor, clock, hs: HomeServer):
("/_synapse/admin/v1/federation/destinations/dummy",),
]
)
def test_requester_is_no_admin(self, url: str):
def test_requester_is_no_admin(self, url: str) -> None:
"""
If the user is not a server admin, an error 403 is returned.
"""
Expand All @@ -62,7 +65,7 @@ def test_requester_is_no_admin(self, url: str):
self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])

def test_invalid_parameter(self):
def test_invalid_parameter(self) -> None:
"""
If parameters are invalid, an error is returned.
"""
Expand Down Expand Up @@ -117,7 +120,7 @@ def test_invalid_parameter(self):
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])

def test_limit(self):
def test_limit(self) -> None:
"""
Testing list of destinations with limit
"""
Expand All @@ -137,7 +140,7 @@ def test_limit(self):
self.assertEqual(channel.json_body["next_token"], "5")
self._check_fields(channel.json_body["destinations"])

def test_from(self):
def test_from(self) -> None:
"""
Testing list of destinations with a defined starting point (from)
"""
Expand All @@ -157,7 +160,7 @@ def test_from(self):
self.assertNotIn("next_token", channel.json_body)
self._check_fields(channel.json_body["destinations"])

def test_limit_and_from(self):
def test_limit_and_from(self) -> None:
"""
Testing list of destinations with a defined starting point and limit
"""
Expand All @@ -177,7 +180,7 @@ def test_limit_and_from(self):
self.assertEqual(len(channel.json_body["destinations"]), 10)
self._check_fields(channel.json_body["destinations"])

def test_next_token(self):
def test_next_token(self) -> None:
"""
Testing that `next_token` appears at the right place
"""
Expand Down Expand Up @@ -238,7 +241,7 @@ def test_next_token(self):
self.assertEqual(len(channel.json_body["destinations"]), 1)
self.assertNotIn("next_token", channel.json_body)

def test_list_all_destinations(self):
def test_list_all_destinations(self) -> None:
"""
List all destinations.
"""
Expand All @@ -259,7 +262,7 @@ def test_list_all_destinations(self):
# Check that all fields are available
self._check_fields(channel.json_body["destinations"])

def test_order_by(self):
def test_order_by(self) -> None:
"""
Testing order list with parameter `order_by`
"""
Expand All @@ -268,7 +271,7 @@ def _order_test(
expected_destination_list: List[str],
order_by: Optional[str],
dir: Optional[str] = None,
):
) -> None:
"""Request the list of destinations in a certain order.
Assert that order is what we expect

Expand Down Expand Up @@ -358,13 +361,13 @@ def _order_test(
[dest[0][0], dest[2][0], dest[1][0]], "last_successful_stream_ordering", "b"
)

def test_search_term(self):
def test_search_term(self) -> None:
"""Test that searching for a destination works correctly"""

def _search_test(
expected_destination: Optional[str],
search_term: str,
):
) -> None:
"""Search for a destination and check that the returned destinationis a match

Args:
Expand Down Expand Up @@ -410,7 +413,7 @@ def _search_test(
_search_test(None, "foo")
_search_test(None, "bar")

def test_get_single_destination(self):
def test_get_single_destination(self) -> None:
"""
Get one specific destinations.
"""
Expand All @@ -429,7 +432,7 @@ def test_get_single_destination(self):
# convert channel.json_body into a List
self._check_fields([channel.json_body])

def _create_destinations(self, number_destinations: int):
def _create_destinations(self, number_destinations: int) -> None:
"""Create a number of destinations

Args:
Expand All @@ -442,7 +445,7 @@ def _create_destinations(self, number_destinations: int):
self.store.set_destination_last_successful_stream_ordering(dest, 100)
)

def _check_fields(self, content: List[JsonDict]):
def _check_fields(self, content: List[JsonDict]) -> None:
"""Checks that the expected destination attributes are present in content

Args:
Expand Down
4 changes: 3 additions & 1 deletion tests/rest/admin/test_media.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,9 @@ def _create_media(self) -> str:

return server_and_media_id

def _access_media(self, server_and_media_id, expect_success=True) -> None:
def _access_media(
self, server_and_media_id: str, expect_success: bool = True
) -> None:
"""
Try to access a media and check the result
"""
Expand Down
25 changes: 16 additions & 9 deletions tests/rest/admin/test_registration_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import random
import string
from http import HTTPStatus
from typing import Optional

from twisted.test.proto_helpers import MemoryReactor

Expand Down Expand Up @@ -42,21 +43,27 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:

self.url = "/_synapse/admin/v1/registration_tokens"

def _new_token(self, **kwargs) -> str:
def _new_token(
self,
token: Optional[str] = None,
uses_allowed: Optional[int] = None,
pending: int = 0,
completed: int = 0,
expiry_time: Optional[int] = None,
) -> str:
"""Helper function to create a token."""
token = kwargs.get(
"token",
"".join(random.choices(string.ascii_letters, k=8)),
)
if token is None:
token = "".join(random.choices(string.ascii_letters, k=8))

self.get_success(
self.store.db_pool.simple_insert(
"registration_tokens",
{
"token": token,
"uses_allowed": kwargs.get("uses_allowed", None),
"pending": kwargs.get("pending", 0),
"completed": kwargs.get("completed", 0),
"expiry_time": kwargs.get("expiry_time", None),
"uses_allowed": uses_allowed,
"pending": pending,
"completed": completed,
"expiry_time": expiry_time,
},
)
)
Expand Down
Loading