Skip to content

Commit

Permalink
Rework OAuth aspects (#136)
Browse files Browse the repository at this point in the history
* Rework

* Tidy tests
  • Loading branch information
inverse authored Feb 20, 2022
1 parent 60fab2b commit 36d4a7a
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 122 deletions.
81 changes: 36 additions & 45 deletions iolite_client/oauth_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ def get_sid_query(access_token: str) -> str:
}
)

@staticmethod
def add_expires_at(token: dict) -> dict:
expires_at = time.time() + token["expires_in"]
token.update({"expires_at": expires_at})
del token["expires_in"]
return token


class OAuthHandler:
def __init__(self, username: str, password: str):
Expand All @@ -62,7 +69,7 @@ def get_access_token(self, code: str, name: str) -> dict:
f"{BASE_URL}/ui/token?{query}", auth=(self.username, self.password)
)
response.raise_for_status()
return json.loads(response.text)
return OAuthHandlerHelper.add_expires_at(json.loads(response.text))

def get_new_access_token(self, refresh_token: str) -> dict:
"""
Expand All @@ -75,7 +82,7 @@ def get_new_access_token(self, refresh_token: str) -> dict:
f"{BASE_URL}/ui/token?{query}", auth=(self.username, self.password)
)
response.raise_for_status()
return json.loads(response.text)
return OAuthHandlerHelper.add_expires_at(json.loads(response.text))

def get_sid(self, access_token: str) -> str:
"""
Expand Down Expand Up @@ -112,7 +119,7 @@ async def get_access_token(self, code: str, name: str) -> dict:
auth=aiohttp.BasicAuth(self.username, self.password),
)
response.raise_for_status()
return await response.json()
return OAuthHandlerHelper.add_expires_at(await response.json())

async def get_new_access_token(self, refresh_token: str) -> dict:
"""
Expand All @@ -126,7 +133,7 @@ async def get_new_access_token(self, refresh_token: str) -> dict:
auth=aiohttp.BasicAuth(self.username, self.password),
)
response.raise_for_status()
return await response.json()
return OAuthHandlerHelper.add_expires_at(await response.json())

async def get_sid(self, access_token: str) -> str:
"""
Expand Down Expand Up @@ -165,8 +172,6 @@ def __init__(self, path: str):
self.path = path

def store_access_token(self, payload: dict):
expires_at = time.time() + payload["expires_in"]
payload.update({"expires_at": expires_at})
self.__store("access_token", payload)

def fetch_access_token(self) -> Optional[dict]:
Expand Down Expand Up @@ -200,39 +205,29 @@ def __init__(
self.oauth_handler = oauth_handler
self.oauth_storage = oauth_storage

def get_sid(self, code: str, name: str) -> str:
def get_sid(self, token: dict) -> str:
"""
Get SID by providing the initial pairing code and the device name you would like to register.
Get SID by access_token.
:param code: The code provided in the QR code
:param name: The name of the device you want to register
:param token: The access token
:return:
"""
access_token = self.oauth_storage.fetch_access_token()

if access_token is None:
logger.debug("No token, requesting")
access_token = self.oauth_handler.get_access_token(code, name)
self.oauth_storage.store_access_token(access_token)

expires_at = access_token["expires_at"]

if expires_at < time.time():
if token["expires_at"] < time.time():
logger.debug("Token expired, refreshing")
token = self._refresh_access_token(access_token)
access_token = self._refresh_access_token(token)
else:
token = access_token["access_token"]
access_token = token["access_token"]

try:
return self.oauth_handler.get_sid(token)
return self.oauth_handler.get_sid(access_token)
except requests.exceptions.HTTPError as e:
logger.debug(f"Invalid token, attempt refresh: {e}")
token = self._refresh_access_token(access_token)
return self.oauth_handler.get_sid(token)
access_token = self._refresh_access_token(token)
return self.oauth_handler.get_sid(access_token)

def _refresh_access_token(self, access_token):
def _refresh_access_token(self, token: dict) -> str:
refreshed_token = self.oauth_handler.get_new_access_token(
access_token["refresh_token"]
token["refresh_token"]
)
self.oauth_storage.store_access_token(refreshed_token)
return refreshed_token["access_token"]
Expand All @@ -247,37 +242,33 @@ def __init__(
self.oauth_handler = oauth_handler
self.oauth_storage = oauth_storage

async def get_sid(self, code: str, name: str) -> str:
access_token = await self.oauth_storage.fetch_access_token()

if access_token is None:
logger.debug("No token, requesting")
access_token = await self.oauth_handler.get_access_token(code, name)
await self.oauth_storage.store_access_token(access_token)
async def get_sid(self, token: dict) -> str:
"""
Get SID by token.
if access_token["expires_at"] < time.time():
:param token: The token
:return:
"""
if token["expires_at"] < time.time():
logger.debug("Token expired, refreshing")
token = await self._refresh_token(access_token)
access_token = await self._refresh_token(token)
else:
token = access_token["access_token"]
access_token = token["access_token"]

logger.debug("Fetched access token")

try:
return await self.oauth_handler.get_sid(token)
return await self.oauth_handler.get_sid(access_token)
except BaseException as e:
logger.debug(f"Invalid token, attempt refresh: {e}")
token = await self._refresh_token(access_token)
return await self.oauth_handler.get_sid(token)
access_token = await self._refresh_token(token)
return await self.oauth_handler.get_sid(access_token)

async def _refresh_token(self, access_token: dict) -> str:
async def _refresh_token(self, token: dict) -> str:
"""Refresh token."""
refreshed_token = await self.oauth_handler.get_new_access_token(
access_token["refresh_token"]
token["refresh_token"]
)
expires_at = time.time() + refreshed_token["expires_in"]
refreshed_token.update({"expires_at": expires_at})
del refreshed_token["expires_in"]
await self.oauth_storage.store_access_token(refreshed_token)

return refreshed_token["access_token"]
21 changes: 20 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ responses = "*"
pytest-socket = "*"
freezegun = "*"
aioresponses = "^0.7.2"
pytest-asyncio = "^0.18.1"

[tool.poetry.extras]
dev = ["environs"]
Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
[pytest]
addopts = --disable-socket
asyncio_mode = strict
8 changes: 7 additions & 1 deletion scripts/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@
oauth_storage = LocalOAuthStorage(".")
oauth_handler = OAuthHandler(USERNAME, PASSWORD)
oauth_wrapper = OAuthWrapper(oauth_handler, oauth_storage)
sid = oauth_wrapper.get_sid(CODE, NAME)

access_token = oauth_storage.fetch_access_token()
if not access_token:
access_token = oauth_handler.get_access_token(CODE, NAME)
oauth_storage.store_access_token(access_token)

sid = oauth_wrapper.get_sid(access_token)

print("------------------")
print(f"URL: https://remote.iolite.de/ui/?SID={sid}")
Expand Down
110 changes: 35 additions & 75 deletions test/test_oauth_handler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import datetime
import json
import unittest
Expand Down Expand Up @@ -46,29 +45,20 @@ def test_get_access_token_valid(self):


class AsyncOAuthHandlerTest(unittest.TestCase):
@staticmethod
def _run_async(callback):
loop = asyncio.get_event_loop()
loop.run_until_complete(callback())

@pytest.mark.asyncio
@pytest.mark.enable_socket
def test_get_access_token_invalid_credentials(self):

async def test_get_access_token_invalid_credentials(self):
with aioresponses() as m:
m.post("https://remote.iolite.de/ui/token", status=403)

async def async_callback():
async with aiohttp.ClientSession() as web_session:
oauth_handler = AsyncOAuthHandler("user", "password", web_session)
with self.assertRaises(
aiohttp.client_exceptions.ClientConnectionError
):
await oauth_handler.get_access_token("dodgy-code", "my-device")

self._run_async(async_callback)
async with aiohttp.ClientSession() as web_session:
oauth_handler = AsyncOAuthHandler("user", "password", web_session)
with self.assertRaises(aiohttp.client_exceptions.ClientConnectionError):
await oauth_handler.get_access_token("dodgy-code", "my-device")

@pytest.mark.asyncio
@pytest.mark.enable_socket
def test_get_access_token_valid(self):
async def test_get_access_token_valid(self):
with aioresponses() as m:
query = OAuthHandlerHelper.get_access_token_query("real-code", "my-device")
m.post(
Expand All @@ -84,68 +74,35 @@ def test_get_access_token_valid(self):
),
)

async def async_callback():
async with aiohttp.ClientSession() as web_session:
oauth_handler = AsyncOAuthHandler("user", "password", web_session)
response = await oauth_handler.get_access_token(
"real-code", "my-device"
)
self.assertIsInstance(response, dict)

self._run_async(async_callback)
async with aiohttp.ClientSession() as web_session:
oauth_handler = AsyncOAuthHandler("user", "password", web_session)
response = await oauth_handler.get_access_token(
"real-code", "my-device"
)
self.assertIsInstance(response, dict)


class OAuthWrapperTest(unittest.TestCase):
def setUp(self) -> None:
self.mock_oauth_handler = Mock()
self.mock_oauth_storage = Mock()
self.oauth_wrapper = OAuthWrapper(
self.mock_oauth_handler, self.mock_oauth_storage
)

@freeze_time("2021-01-01 00:00:00")
def test_get_sid_valid_access_token(self):
self.mock_oauth_storage.fetch_access_token.return_value = {
"expires_at": datetime.datetime(2021, 1, 1, 0, 0, 1).timestamp(),
"access_token": "access-token",
}

oauth_wrapper = OAuthWrapper(self.mock_oauth_handler, self.mock_oauth_storage)

oauth_wrapper.get_sid("my-code", "my-device")
self.mock_oauth_handler.get_sid.assert_called_once_with("access-token")

@freeze_time("2021-01-01 00:00:00")
def test_get_sid_nothing_stored(self):
self.mock_oauth_storage.fetch_access_token.return_value = None

response = {
"expires_at": datetime.datetime(2021, 1, 1, 0, 0, 1).timestamp(),
"access_token": "access-token",
}

self.mock_oauth_handler.get_access_token.return_value = response

oauth_wrapper = OAuthWrapper(self.mock_oauth_handler, self.mock_oauth_storage)

oauth_wrapper.get_sid("my-code", "my-device")
self.mock_oauth_storage.store_access_token.assert_called_once_with(response)
token = self._get_token(datetime.datetime(2021, 1, 1, 0, 0, 1))
self.oauth_wrapper.get_sid(token)
self.mock_oauth_handler.get_sid.assert_called_once_with("access-token")

@freeze_time("2021-01-01 00:00:01")
def test_get_sid_expired_access_token(self):
self.mock_oauth_storage.fetch_access_token.return_value = {
"expires_at": datetime.datetime(2021, 1, 1, 0, 0, 0).timestamp(),
"access_token": "access-token",
"refresh_token": "refresh-token",
}

response = {
"expires_at": datetime.datetime(2021, 1, 10, 0, 0, 0).timestamp(),
"access_token": "access-token",
}

token = self._get_token(datetime.datetime(2021, 1, 1, 0, 0, 0))
response = self._get_token(datetime.datetime(2021, 1, 10, 0, 0, 0))
self.mock_oauth_handler.get_new_access_token.return_value = response

oauth_wrapper = OAuthWrapper(self.mock_oauth_handler, self.mock_oauth_storage)
oauth_wrapper.get_sid("my-code", "my-device")
self.oauth_wrapper.get_sid(token)
self.mock_oauth_handler.get_new_access_token.assert_called_once_with(
"refresh-token"
)
Expand All @@ -154,25 +111,28 @@ def test_get_sid_expired_access_token(self):

@freeze_time("2021-01-01 00:00:00")
def test_invalid_token_refresh(self):
self.mock_oauth_storage.fetch_access_token.return_value = {
"expires_at": datetime.datetime(2021, 1, 1, 0, 0, 1).timestamp(),
"access_token": "access-token",
"refresh_token": "refresh-token",
}

token = self._get_token(datetime.datetime(2021, 1, 1, 0, 0, 1))
self.mock_oauth_storage.fetch_access_token.return_value = token

self.mock_oauth_handler.get_sid.side_effect = [
HTTPError("Something went wrong"),
"sid",
]

response = {
"expires_at": datetime.datetime(2021, 1, 10, 0, 0, 0).timestamp(),
"access_token": "access-token",
}
response = self._get_token(datetime.datetime(2021, 1, 10, 0, 0, 0))

self.mock_oauth_handler.get_new_access_token.return_value = response

oauth_wrapper = OAuthWrapper(self.mock_oauth_handler, self.mock_oauth_storage)
oauth_wrapper.get_sid("my-code", "my-device")
oauth_wrapper.get_sid(token)

self.assertEqual(self.mock_oauth_handler.get_sid.call_count, 2)

@staticmethod
def _get_token(date_time: datetime.datetime) -> dict:
return {
"expires_at": date_time.timestamp(),
"access_token": "access-token",
"refresh_token": "refresh-token",
}

0 comments on commit 36d4a7a

Please sign in to comment.