Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework OAuth aspects #136

Merged
merged 2 commits into from
Feb 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 36 additions & 45 deletions iolite_client/oauth_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ def get_sid_query(access_token: str) -> str:
}
)

@staticmethod
def add_expires_at(token: dict) -> dict:
expires_at = time.time() + token["expires_in"]
token.update({"expires_at": expires_at})
del token["expires_in"]
return token


class OAuthHandler:
def __init__(self, username: str, password: str):
Expand All @@ -62,7 +69,7 @@ def get_access_token(self, code: str, name: str) -> dict:
f"{BASE_URL}/ui/token?{query}", auth=(self.username, self.password)
)
response.raise_for_status()
return json.loads(response.text)
return OAuthHandlerHelper.add_expires_at(json.loads(response.text))

def get_new_access_token(self, refresh_token: str) -> dict:
"""
Expand All @@ -75,7 +82,7 @@ def get_new_access_token(self, refresh_token: str) -> dict:
f"{BASE_URL}/ui/token?{query}", auth=(self.username, self.password)
)
response.raise_for_status()
return json.loads(response.text)
return OAuthHandlerHelper.add_expires_at(json.loads(response.text))

def get_sid(self, access_token: str) -> str:
"""
Expand Down Expand Up @@ -112,7 +119,7 @@ async def get_access_token(self, code: str, name: str) -> dict:
auth=aiohttp.BasicAuth(self.username, self.password),
)
response.raise_for_status()
return await response.json()
return OAuthHandlerHelper.add_expires_at(await response.json())

async def get_new_access_token(self, refresh_token: str) -> dict:
"""
Expand All @@ -126,7 +133,7 @@ async def get_new_access_token(self, refresh_token: str) -> dict:
auth=aiohttp.BasicAuth(self.username, self.password),
)
response.raise_for_status()
return await response.json()
return OAuthHandlerHelper.add_expires_at(await response.json())

async def get_sid(self, access_token: str) -> str:
"""
Expand Down Expand Up @@ -165,8 +172,6 @@ def __init__(self, path: str):
self.path = path

def store_access_token(self, payload: dict):
expires_at = time.time() + payload["expires_in"]
payload.update({"expires_at": expires_at})
self.__store("access_token", payload)

def fetch_access_token(self) -> Optional[dict]:
Expand Down Expand Up @@ -200,39 +205,29 @@ def __init__(
self.oauth_handler = oauth_handler
self.oauth_storage = oauth_storage

def get_sid(self, code: str, name: str) -> str:
def get_sid(self, token: dict) -> str:
"""
Get SID by providing the initial pairing code and the device name you would like to register.
Get SID by access_token.

:param code: The code provided in the QR code
:param name: The name of the device you want to register
:param token: The access token
:return:
"""
access_token = self.oauth_storage.fetch_access_token()

if access_token is None:
logger.debug("No token, requesting")
access_token = self.oauth_handler.get_access_token(code, name)
self.oauth_storage.store_access_token(access_token)

expires_at = access_token["expires_at"]

if expires_at < time.time():
if token["expires_at"] < time.time():
logger.debug("Token expired, refreshing")
token = self._refresh_access_token(access_token)
access_token = self._refresh_access_token(token)
else:
token = access_token["access_token"]
access_token = token["access_token"]

try:
return self.oauth_handler.get_sid(token)
return self.oauth_handler.get_sid(access_token)
except requests.exceptions.HTTPError as e:
logger.debug(f"Invalid token, attempt refresh: {e}")
token = self._refresh_access_token(access_token)
return self.oauth_handler.get_sid(token)
access_token = self._refresh_access_token(token)
return self.oauth_handler.get_sid(access_token)

def _refresh_access_token(self, access_token):
def _refresh_access_token(self, token: dict) -> str:
refreshed_token = self.oauth_handler.get_new_access_token(
access_token["refresh_token"]
token["refresh_token"]
)
self.oauth_storage.store_access_token(refreshed_token)
return refreshed_token["access_token"]
Expand All @@ -247,37 +242,33 @@ def __init__(
self.oauth_handler = oauth_handler
self.oauth_storage = oauth_storage

async def get_sid(self, code: str, name: str) -> str:
access_token = await self.oauth_storage.fetch_access_token()

if access_token is None:
logger.debug("No token, requesting")
access_token = await self.oauth_handler.get_access_token(code, name)
await self.oauth_storage.store_access_token(access_token)
async def get_sid(self, token: dict) -> str:
"""
Get SID by token.

if access_token["expires_at"] < time.time():
:param token: The token
:return:
"""
if token["expires_at"] < time.time():
logger.debug("Token expired, refreshing")
token = await self._refresh_token(access_token)
access_token = await self._refresh_token(token)
else:
token = access_token["access_token"]
access_token = token["access_token"]

logger.debug("Fetched access token")

try:
return await self.oauth_handler.get_sid(token)
return await self.oauth_handler.get_sid(access_token)
except BaseException as e:
logger.debug(f"Invalid token, attempt refresh: {e}")
token = await self._refresh_token(access_token)
return await self.oauth_handler.get_sid(token)
access_token = await self._refresh_token(token)
return await self.oauth_handler.get_sid(access_token)

async def _refresh_token(self, access_token: dict) -> str:
async def _refresh_token(self, token: dict) -> str:
"""Refresh token."""
refreshed_token = await self.oauth_handler.get_new_access_token(
access_token["refresh_token"]
token["refresh_token"]
)
expires_at = time.time() + refreshed_token["expires_in"]
refreshed_token.update({"expires_at": expires_at})
del refreshed_token["expires_in"]
await self.oauth_storage.store_access_token(refreshed_token)

return refreshed_token["access_token"]
21 changes: 20 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ responses = "*"
pytest-socket = "*"
freezegun = "*"
aioresponses = "^0.7.2"
pytest-asyncio = "^0.18.1"

[tool.poetry.extras]
dev = ["environs"]
Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
[pytest]
addopts = --disable-socket
asyncio_mode = strict
8 changes: 7 additions & 1 deletion scripts/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@
oauth_storage = LocalOAuthStorage(".")
oauth_handler = OAuthHandler(USERNAME, PASSWORD)
oauth_wrapper = OAuthWrapper(oauth_handler, oauth_storage)
sid = oauth_wrapper.get_sid(CODE, NAME)

access_token = oauth_storage.fetch_access_token()
if not access_token:
access_token = oauth_handler.get_access_token(CODE, NAME)
oauth_storage.store_access_token(access_token)

sid = oauth_wrapper.get_sid(access_token)

print("------------------")
print(f"URL: https://remote.iolite.de/ui/?SID={sid}")
Expand Down
110 changes: 35 additions & 75 deletions test/test_oauth_handler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import datetime
import json
import unittest
Expand Down Expand Up @@ -46,29 +45,20 @@ def test_get_access_token_valid(self):


class AsyncOAuthHandlerTest(unittest.TestCase):
@staticmethod
def _run_async(callback):
loop = asyncio.get_event_loop()
loop.run_until_complete(callback())

@pytest.mark.asyncio
@pytest.mark.enable_socket
def test_get_access_token_invalid_credentials(self):

async def test_get_access_token_invalid_credentials(self):
with aioresponses() as m:
m.post("https://remote.iolite.de/ui/token", status=403)

async def async_callback():
async with aiohttp.ClientSession() as web_session:
oauth_handler = AsyncOAuthHandler("user", "password", web_session)
with self.assertRaises(
aiohttp.client_exceptions.ClientConnectionError
):
await oauth_handler.get_access_token("dodgy-code", "my-device")

self._run_async(async_callback)
async with aiohttp.ClientSession() as web_session:
oauth_handler = AsyncOAuthHandler("user", "password", web_session)
with self.assertRaises(aiohttp.client_exceptions.ClientConnectionError):
await oauth_handler.get_access_token("dodgy-code", "my-device")

@pytest.mark.asyncio
@pytest.mark.enable_socket
def test_get_access_token_valid(self):
async def test_get_access_token_valid(self):
with aioresponses() as m:
query = OAuthHandlerHelper.get_access_token_query("real-code", "my-device")
m.post(
Expand All @@ -84,68 +74,35 @@ def test_get_access_token_valid(self):
),
)

async def async_callback():
async with aiohttp.ClientSession() as web_session:
oauth_handler = AsyncOAuthHandler("user", "password", web_session)
response = await oauth_handler.get_access_token(
"real-code", "my-device"
)
self.assertIsInstance(response, dict)

self._run_async(async_callback)
async with aiohttp.ClientSession() as web_session:
oauth_handler = AsyncOAuthHandler("user", "password", web_session)
response = await oauth_handler.get_access_token(
"real-code", "my-device"
)
self.assertIsInstance(response, dict)


class OAuthWrapperTest(unittest.TestCase):
def setUp(self) -> None:
self.mock_oauth_handler = Mock()
self.mock_oauth_storage = Mock()
self.oauth_wrapper = OAuthWrapper(
self.mock_oauth_handler, self.mock_oauth_storage
)

@freeze_time("2021-01-01 00:00:00")
def test_get_sid_valid_access_token(self):
self.mock_oauth_storage.fetch_access_token.return_value = {
"expires_at": datetime.datetime(2021, 1, 1, 0, 0, 1).timestamp(),
"access_token": "access-token",
}

oauth_wrapper = OAuthWrapper(self.mock_oauth_handler, self.mock_oauth_storage)

oauth_wrapper.get_sid("my-code", "my-device")
self.mock_oauth_handler.get_sid.assert_called_once_with("access-token")

@freeze_time("2021-01-01 00:00:00")
def test_get_sid_nothing_stored(self):
self.mock_oauth_storage.fetch_access_token.return_value = None

response = {
"expires_at": datetime.datetime(2021, 1, 1, 0, 0, 1).timestamp(),
"access_token": "access-token",
}

self.mock_oauth_handler.get_access_token.return_value = response

oauth_wrapper = OAuthWrapper(self.mock_oauth_handler, self.mock_oauth_storage)

oauth_wrapper.get_sid("my-code", "my-device")
self.mock_oauth_storage.store_access_token.assert_called_once_with(response)
token = self._get_token(datetime.datetime(2021, 1, 1, 0, 0, 1))
self.oauth_wrapper.get_sid(token)
self.mock_oauth_handler.get_sid.assert_called_once_with("access-token")

@freeze_time("2021-01-01 00:00:01")
def test_get_sid_expired_access_token(self):
self.mock_oauth_storage.fetch_access_token.return_value = {
"expires_at": datetime.datetime(2021, 1, 1, 0, 0, 0).timestamp(),
"access_token": "access-token",
"refresh_token": "refresh-token",
}

response = {
"expires_at": datetime.datetime(2021, 1, 10, 0, 0, 0).timestamp(),
"access_token": "access-token",
}

token = self._get_token(datetime.datetime(2021, 1, 1, 0, 0, 0))
response = self._get_token(datetime.datetime(2021, 1, 10, 0, 0, 0))
self.mock_oauth_handler.get_new_access_token.return_value = response

oauth_wrapper = OAuthWrapper(self.mock_oauth_handler, self.mock_oauth_storage)
oauth_wrapper.get_sid("my-code", "my-device")
self.oauth_wrapper.get_sid(token)
self.mock_oauth_handler.get_new_access_token.assert_called_once_with(
"refresh-token"
)
Expand All @@ -154,25 +111,28 @@ def test_get_sid_expired_access_token(self):

@freeze_time("2021-01-01 00:00:00")
def test_invalid_token_refresh(self):
self.mock_oauth_storage.fetch_access_token.return_value = {
"expires_at": datetime.datetime(2021, 1, 1, 0, 0, 1).timestamp(),
"access_token": "access-token",
"refresh_token": "refresh-token",
}

token = self._get_token(datetime.datetime(2021, 1, 1, 0, 0, 1))
self.mock_oauth_storage.fetch_access_token.return_value = token

self.mock_oauth_handler.get_sid.side_effect = [
HTTPError("Something went wrong"),
"sid",
]

response = {
"expires_at": datetime.datetime(2021, 1, 10, 0, 0, 0).timestamp(),
"access_token": "access-token",
}
response = self._get_token(datetime.datetime(2021, 1, 10, 0, 0, 0))

self.mock_oauth_handler.get_new_access_token.return_value = response

oauth_wrapper = OAuthWrapper(self.mock_oauth_handler, self.mock_oauth_storage)
oauth_wrapper.get_sid("my-code", "my-device")
oauth_wrapper.get_sid(token)

self.assertEqual(self.mock_oauth_handler.get_sid.call_count, 2)

@staticmethod
def _get_token(date_time: datetime.datetime) -> dict:
return {
"expires_at": date_time.timestamp(),
"access_token": "access-token",
"refresh_token": "refresh-token",
}