diff --git a/iolite_client/oauth_handler.py b/iolite_client/oauth_handler.py index 60d410d..828fd2a 100644 --- a/iolite_client/oauth_handler.py +++ b/iolite_client/oauth_handler.py @@ -44,6 +44,13 @@ def get_sid_query(access_token: str) -> str: } ) + @staticmethod + def add_expires_at(token: dict) -> dict: + expires_at = time.time() + token["expires_in"] + token.update({"expires_at": expires_at}) + del token["expires_in"] + return token + class OAuthHandler: def __init__(self, username: str, password: str): @@ -62,7 +69,7 @@ def get_access_token(self, code: str, name: str) -> dict: f"{BASE_URL}/ui/token?{query}", auth=(self.username, self.password) ) response.raise_for_status() - return json.loads(response.text) + return OAuthHandlerHelper.add_expires_at(json.loads(response.text)) def get_new_access_token(self, refresh_token: str) -> dict: """ @@ -75,7 +82,7 @@ def get_new_access_token(self, refresh_token: str) -> dict: f"{BASE_URL}/ui/token?{query}", auth=(self.username, self.password) ) response.raise_for_status() - return json.loads(response.text) + return OAuthHandlerHelper.add_expires_at(json.loads(response.text)) def get_sid(self, access_token: str) -> str: """ @@ -112,7 +119,7 @@ async def get_access_token(self, code: str, name: str) -> dict: auth=aiohttp.BasicAuth(self.username, self.password), ) response.raise_for_status() - return await response.json() + return OAuthHandlerHelper.add_expires_at(await response.json()) async def get_new_access_token(self, refresh_token: str) -> dict: """ @@ -126,7 +133,7 @@ async def get_new_access_token(self, refresh_token: str) -> dict: auth=aiohttp.BasicAuth(self.username, self.password), ) response.raise_for_status() - return await response.json() + return OAuthHandlerHelper.add_expires_at(await response.json()) async def get_sid(self, access_token: str) -> str: """ @@ -165,8 +172,6 @@ def __init__(self, path: str): self.path = path def store_access_token(self, payload: dict): - expires_at = time.time() + payload["expires_in"] - payload.update({"expires_at": expires_at}) self.__store("access_token", payload) def fetch_access_token(self) -> Optional[dict]: @@ -200,39 +205,29 @@ def __init__( self.oauth_handler = oauth_handler self.oauth_storage = oauth_storage - def get_sid(self, code: str, name: str) -> str: + def get_sid(self, token: dict) -> str: """ - Get SID by providing the initial pairing code and the device name you would like to register. + Get SID by access_token. - :param code: The code provided in the QR code - :param name: The name of the device you want to register + :param token: The access token :return: """ - access_token = self.oauth_storage.fetch_access_token() - - if access_token is None: - logger.debug("No token, requesting") - access_token = self.oauth_handler.get_access_token(code, name) - self.oauth_storage.store_access_token(access_token) - - expires_at = access_token["expires_at"] - - if expires_at < time.time(): + if token["expires_at"] < time.time(): logger.debug("Token expired, refreshing") - token = self._refresh_access_token(access_token) + access_token = self._refresh_access_token(token) else: - token = access_token["access_token"] + access_token = token["access_token"] try: - return self.oauth_handler.get_sid(token) + return self.oauth_handler.get_sid(access_token) except requests.exceptions.HTTPError as e: logger.debug(f"Invalid token, attempt refresh: {e}") - token = self._refresh_access_token(access_token) - return self.oauth_handler.get_sid(token) + access_token = self._refresh_access_token(token) + return self.oauth_handler.get_sid(access_token) - def _refresh_access_token(self, access_token): + def _refresh_access_token(self, token: dict) -> str: refreshed_token = self.oauth_handler.get_new_access_token( - access_token["refresh_token"] + token["refresh_token"] ) self.oauth_storage.store_access_token(refreshed_token) return refreshed_token["access_token"] @@ -247,37 +242,33 @@ def __init__( self.oauth_handler = oauth_handler self.oauth_storage = oauth_storage - async def get_sid(self, code: str, name: str) -> str: - access_token = await self.oauth_storage.fetch_access_token() - - if access_token is None: - logger.debug("No token, requesting") - access_token = await self.oauth_handler.get_access_token(code, name) - await self.oauth_storage.store_access_token(access_token) + async def get_sid(self, token: dict) -> str: + """ + Get SID by token. - if access_token["expires_at"] < time.time(): + :param token: The token + :return: + """ + if token["expires_at"] < time.time(): logger.debug("Token expired, refreshing") - token = await self._refresh_token(access_token) + access_token = await self._refresh_token(token) else: - token = access_token["access_token"] + access_token = token["access_token"] logger.debug("Fetched access token") try: - return await self.oauth_handler.get_sid(token) + return await self.oauth_handler.get_sid(access_token) except BaseException as e: logger.debug(f"Invalid token, attempt refresh: {e}") - token = await self._refresh_token(access_token) - return await self.oauth_handler.get_sid(token) + access_token = await self._refresh_token(token) + return await self.oauth_handler.get_sid(access_token) - async def _refresh_token(self, access_token: dict) -> str: + async def _refresh_token(self, token: dict) -> str: """Refresh token.""" refreshed_token = await self.oauth_handler.get_new_access_token( - access_token["refresh_token"] + token["refresh_token"] ) - expires_at = time.time() + refreshed_token["expires_in"] - refreshed_token.update({"expires_at": expires_at}) - del refreshed_token["expires_in"] await self.oauth_storage.store_access_token(refreshed_token) return refreshed_token["access_token"] diff --git a/poetry.lock b/poetry.lock index 3a0d3c6..5100f8b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -358,6 +358,21 @@ tomli = ">=1.0.0" [package.extras] testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.18.1" +description = "Pytest support for asyncio" +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +pytest = ">=6.1.0" +typing-extensions = {version = ">=3.7.2", markers = "python_version < \"3.8\""} + +[package.extras] +testing = ["coverage (==6.2)", "hypothesis (>=5.7.1)", "flaky (>=3.5.0)", "mypy (==0.931)"] + [[package]] name = "pytest-cov" version = "3.0.0" @@ -561,7 +576,7 @@ dev = ["environs"] [metadata] lock-version = "1.1" python-versions = "^3.7" -content-hash = "7fff54def948819f26d862f00ae0ba26b90a8f997763c7306974ca89cc20db86" +content-hash = "9498aef5536c5be7ff2f220d0b8c7c8bd47e3706b76b731702a780d3ddc7abaf" [metadata.files] aiohttp = [ @@ -911,6 +926,10 @@ pytest = [ {file = "pytest-7.0.1-py3-none-any.whl", hash = "sha256:9ce3ff477af913ecf6321fe337b93a2c0dcf2a0a1439c43f5452112c1e4280db"}, {file = "pytest-7.0.1.tar.gz", hash = "sha256:e30905a0c131d3d94b89624a1cc5afec3e0ba2fbdb151867d8e0ebd49850f171"}, ] +pytest-asyncio = [ + {file = "pytest-asyncio-0.18.1.tar.gz", hash = "sha256:c43fcdfea2335dd82ffe0f2774e40285ddfea78a8e81e56118d47b6a90fbb09e"}, + {file = "pytest_asyncio-0.18.1-py3-none-any.whl", hash = "sha256:c9ec48e8bbf5cc62755e18c4d8bc6907843ec9c5f4ac8f61464093baeba24a7e"}, +] pytest-cov = [ {file = "pytest-cov-3.0.0.tar.gz", hash = "sha256:e7f0f5b1617d2210a2cabc266dfe2f4c75a8d32fb89eafb7ad9d06f6d076d470"}, {file = "pytest_cov-3.0.0-py3-none-any.whl", hash = "sha256:578d5d15ac4a25e5f961c938b85a05b09fdaae9deef3bb6de9a6e766622ca7a6"}, diff --git a/pyproject.toml b/pyproject.toml index 3b935eb..970283e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ responses = "*" pytest-socket = "*" freezegun = "*" aioresponses = "^0.7.2" +pytest-asyncio = "^0.18.1" [tool.poetry.extras] dev = ["environs"] diff --git a/pytest.ini b/pytest.ini index 5af7026..6f5a510 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,2 +1,3 @@ [pytest] addopts = --disable-socket +asyncio_mode = strict diff --git a/scripts/example.py b/scripts/example.py index 8882240..ab88f10 100644 --- a/scripts/example.py +++ b/scripts/example.py @@ -23,7 +23,13 @@ oauth_storage = LocalOAuthStorage(".") oauth_handler = OAuthHandler(USERNAME, PASSWORD) oauth_wrapper = OAuthWrapper(oauth_handler, oauth_storage) -sid = oauth_wrapper.get_sid(CODE, NAME) + +access_token = oauth_storage.fetch_access_token() +if not access_token: + access_token = oauth_handler.get_access_token(CODE, NAME) + oauth_storage.store_access_token(access_token) + +sid = oauth_wrapper.get_sid(access_token) print("------------------") print(f"URL: https://remote.iolite.de/ui/?SID={sid}") diff --git a/test/test_oauth_handler.py b/test/test_oauth_handler.py index 07318f5..50e031e 100644 --- a/test/test_oauth_handler.py +++ b/test/test_oauth_handler.py @@ -1,4 +1,3 @@ -import asyncio import datetime import json import unittest @@ -46,29 +45,20 @@ def test_get_access_token_valid(self): class AsyncOAuthHandlerTest(unittest.TestCase): - @staticmethod - def _run_async(callback): - loop = asyncio.get_event_loop() - loop.run_until_complete(callback()) - + @pytest.mark.asyncio @pytest.mark.enable_socket - def test_get_access_token_invalid_credentials(self): - + async def test_get_access_token_invalid_credentials(self): with aioresponses() as m: m.post("https://remote.iolite.de/ui/token", status=403) - async def async_callback(): - async with aiohttp.ClientSession() as web_session: - oauth_handler = AsyncOAuthHandler("user", "password", web_session) - with self.assertRaises( - aiohttp.client_exceptions.ClientConnectionError - ): - await oauth_handler.get_access_token("dodgy-code", "my-device") - - self._run_async(async_callback) + async with aiohttp.ClientSession() as web_session: + oauth_handler = AsyncOAuthHandler("user", "password", web_session) + with self.assertRaises(aiohttp.client_exceptions.ClientConnectionError): + await oauth_handler.get_access_token("dodgy-code", "my-device") + @pytest.mark.asyncio @pytest.mark.enable_socket - def test_get_access_token_valid(self): + async def test_get_access_token_valid(self): with aioresponses() as m: query = OAuthHandlerHelper.get_access_token_query("real-code", "my-device") m.post( @@ -84,68 +74,35 @@ def test_get_access_token_valid(self): ), ) - async def async_callback(): - async with aiohttp.ClientSession() as web_session: - oauth_handler = AsyncOAuthHandler("user", "password", web_session) - response = await oauth_handler.get_access_token( - "real-code", "my-device" - ) - self.assertIsInstance(response, dict) - - self._run_async(async_callback) + async with aiohttp.ClientSession() as web_session: + oauth_handler = AsyncOAuthHandler("user", "password", web_session) + response = await oauth_handler.get_access_token( + "real-code", "my-device" + ) + self.assertIsInstance(response, dict) class OAuthWrapperTest(unittest.TestCase): def setUp(self) -> None: self.mock_oauth_handler = Mock() self.mock_oauth_storage = Mock() + self.oauth_wrapper = OAuthWrapper( + self.mock_oauth_handler, self.mock_oauth_storage + ) @freeze_time("2021-01-01 00:00:00") def test_get_sid_valid_access_token(self): - self.mock_oauth_storage.fetch_access_token.return_value = { - "expires_at": datetime.datetime(2021, 1, 1, 0, 0, 1).timestamp(), - "access_token": "access-token", - } - - oauth_wrapper = OAuthWrapper(self.mock_oauth_handler, self.mock_oauth_storage) - - oauth_wrapper.get_sid("my-code", "my-device") - self.mock_oauth_handler.get_sid.assert_called_once_with("access-token") - - @freeze_time("2021-01-01 00:00:00") - def test_get_sid_nothing_stored(self): - self.mock_oauth_storage.fetch_access_token.return_value = None - - response = { - "expires_at": datetime.datetime(2021, 1, 1, 0, 0, 1).timestamp(), - "access_token": "access-token", - } - - self.mock_oauth_handler.get_access_token.return_value = response - - oauth_wrapper = OAuthWrapper(self.mock_oauth_handler, self.mock_oauth_storage) - - oauth_wrapper.get_sid("my-code", "my-device") - self.mock_oauth_storage.store_access_token.assert_called_once_with(response) + token = self._get_token(datetime.datetime(2021, 1, 1, 0, 0, 1)) + self.oauth_wrapper.get_sid(token) self.mock_oauth_handler.get_sid.assert_called_once_with("access-token") @freeze_time("2021-01-01 00:00:01") def test_get_sid_expired_access_token(self): - self.mock_oauth_storage.fetch_access_token.return_value = { - "expires_at": datetime.datetime(2021, 1, 1, 0, 0, 0).timestamp(), - "access_token": "access-token", - "refresh_token": "refresh-token", - } - - response = { - "expires_at": datetime.datetime(2021, 1, 10, 0, 0, 0).timestamp(), - "access_token": "access-token", - } - + token = self._get_token(datetime.datetime(2021, 1, 1, 0, 0, 0)) + response = self._get_token(datetime.datetime(2021, 1, 10, 0, 0, 0)) self.mock_oauth_handler.get_new_access_token.return_value = response - oauth_wrapper = OAuthWrapper(self.mock_oauth_handler, self.mock_oauth_storage) - oauth_wrapper.get_sid("my-code", "my-device") + self.oauth_wrapper.get_sid(token) self.mock_oauth_handler.get_new_access_token.assert_called_once_with( "refresh-token" ) @@ -154,25 +111,28 @@ def test_get_sid_expired_access_token(self): @freeze_time("2021-01-01 00:00:00") def test_invalid_token_refresh(self): - self.mock_oauth_storage.fetch_access_token.return_value = { - "expires_at": datetime.datetime(2021, 1, 1, 0, 0, 1).timestamp(), - "access_token": "access-token", - "refresh_token": "refresh-token", - } + + token = self._get_token(datetime.datetime(2021, 1, 1, 0, 0, 1)) + self.mock_oauth_storage.fetch_access_token.return_value = token self.mock_oauth_handler.get_sid.side_effect = [ HTTPError("Something went wrong"), "sid", ] - response = { - "expires_at": datetime.datetime(2021, 1, 10, 0, 0, 0).timestamp(), - "access_token": "access-token", - } + response = self._get_token(datetime.datetime(2021, 1, 10, 0, 0, 0)) self.mock_oauth_handler.get_new_access_token.return_value = response oauth_wrapper = OAuthWrapper(self.mock_oauth_handler, self.mock_oauth_storage) - oauth_wrapper.get_sid("my-code", "my-device") + oauth_wrapper.get_sid(token) self.assertEqual(self.mock_oauth_handler.get_sid.call_count, 2) + + @staticmethod + def _get_token(date_time: datetime.datetime) -> dict: + return { + "expires_at": date_time.timestamp(), + "access_token": "access-token", + "refresh_token": "refresh-token", + }