Skip to content
This repository has been archived by the owner on Jul 23, 2024. It is now read-only.

Commit

Permalink
Base implementation of ClientDatabase classes (CZ-NIC#479)
Browse files Browse the repository at this point in the history
  • Loading branch information
tpazderka authored and andrewkrug committed Jun 6, 2019
1 parent 2e0cb5b commit 94c2276
Show file tree
Hide file tree
Showing 8 changed files with 239 additions and 36 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ The format is based on the [KeepAChangeLog] project.
- [#469] Allow endpoints to have query parts
- [#443] Ability to specify additional supported claims for oic.Provider
- [#134] Added method kwarg to registration_endpoint that enables the client to read/modify registration
- [#478] Addedd base-class for Client databases ``oic.utils.clientdb.BaseClientDatabase``

### Changed
- [#134] ``l_registration_enpoint`` has been deprecated, use ``create_registration`` instead
Expand Down Expand Up @@ -42,6 +43,7 @@ The format is based on the [KeepAChangeLog] project.
[#471]: https://github.com/OpenIDC/pyoidc/issues/471
[#352]: https://github.com/OpenIDC/pyoidc/issues/352
[#475]: https://github.com/OpenIDC/pyoidc/issues/475
[#478]: https://github.com/OpenIDC/pyoidc/issues/478

## 0.12.0 [2017-09-25]

Expand Down
1 change: 0 additions & 1 deletion Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ url = "https://pypi.python.org/simple"

[dev-packages]
Sphinx = "*"
httpretty = "*"
isort = "*"
mock = "*"
pylama = "*"
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def run_tests(self):
errno = pytest.main(self.test_args)
sys.exit(errno)

tests_requires = ['responses', 'testfixtures', 'pytest', 'mock', 'freezegun', 'httpretty']
tests_requires = ['responses', 'testfixtures', 'pytest', 'mock', 'freezegun']

# Python 2.7 and later ship with importlib and argparse
if sys.version_info[0] == 2 and sys.version_info[1] == 6:
Expand Down
5 changes: 5 additions & 0 deletions src/oic/oauth2/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import os
import sys
import traceback
import warnings
from http.cookies import SimpleCookie

from six import PY2
Expand Down Expand Up @@ -43,6 +44,7 @@
from oic.utils.authn.user import NoSuchAuthentication
from oic.utils.authn.user import TamperAllert
from oic.utils.authn.user import ToOld
from oic.utils.clientdb import BaseClientDatabase
from oic.utils.http_util import OAUTH2_NOCACHE_HEADERS
from oic.utils.http_util import BadRequest
from oic.utils.http_util import CookieDealer
Expand Down Expand Up @@ -158,6 +160,9 @@ def __init__(self, name, sdb, cdb, authn_broker, authz, client_authn,
baseurl='', server_cls=Server, client_cert=None):
self.name = name
self.sdb = sdb
if not isinstance(cdb, BaseClientDatabase):
warnings.warn('ClientDatabase should be an instance of '
'oic.utils.clientdb.BaseClientDatabase to ensure proper API.')
self.cdb = cdb
self.server = server_cls(verify_ssl=verify_ssl, client_cert=client_cert)

Expand Down
5 changes: 4 additions & 1 deletion src/oic/utils/client_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from oic import rndstr
from oic.oic.provider import secret
from oic.utils.clientdb import BaseClientDatabase

__author__ = 'rolandh'

Expand Down Expand Up @@ -43,7 +44,9 @@ def pack_redirect_uri(redirect_uris):
return ruri


class CDB(object):
class CDB(BaseClientDatabase):
"""Implementation of ClientDatabase with shelve."""

def __init__(self, filename):
self.cdb = shelve.open(filename, writeback=True)
self.seed = rndstr(32).encode("utf-8")
Expand Down
116 changes: 107 additions & 9 deletions src/oic/utils/clientdb.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,120 @@
import json
"""Client managament databases."""
from abc import ABCMeta
from abc import abstractmethod

import requests
from six import with_metaclass
from six.moves.urllib.parse import quote
from six.moves.urllib.parse import urljoin

from oic.oauth2.exception import NoClientInfoReceivedError


class MDQClient(object):
class BaseClientDatabase(with_metaclass(ABCMeta)):
"""
Base implementation for Client management database.
Custom Client databases should derive from this class.
They must implement the following methods:
* ``__getitem__(self, key)``
* ``__setitem__(self, key, value)``
* ``__delitem__(self, key)``
* ``keys(self)``
* ``items(self)``
"""

def __init__(self):
"""Perform initialization of storage. Derived classes may override."""

@abstractmethod
def __getitem__(self, key):
"""Retrieve an item by a key. Raises KeyError if item not found."""
pass # pragma: no cover

def get(self, key, default=None):
"""Retrieve an item by a key. Return default if not found."""
try:
return self[key]
except KeyError:
return default

@abstractmethod
def __setitem__(self, key, value):
"""Set key with value."""
pass # pragma: no cover

@abstractmethod
def __delitem__(self, key):
"""Remove key from database."""
pass # pragma: no cover

def __contains__(self, key):
"""Return True if key is contained in the database."""
try:
self[key]
except KeyError:
return False
else:
return True

@abstractmethod
def keys(self):
"""Return all contained keys."""
pass # pragma: no cover

@abstractmethod
def items(self):
"""Return list of all contained items."""
pass # pragma: no cover

def __len__(self):
"""Return number of contained keys."""
return len(self.keys())


class MDQClient(BaseClientDatabase):
"""Implementation of remote client database."""

def __init__(self, url):
"""Set the remote storage url."""
self.url = url
self.headers = {'Accept': 'application/json', 'Accept-Encoding': 'gzip'}

def __getitem__(self, item):
mdx_url = "{}/entities/{}".format(self.url, quote(item, safe=''))
response = requests.request("GET", mdx_url,
headers={'Accept': 'application/json',
'Accept-Encoding': 'gzip'})
"""Retrieve a single entity."""
mdx_url = urljoin(self.url, 'entities/{}'.format(quote(item, safe='')))
response = requests.get(mdx_url, headers=self.headers)
if response.status_code == 200:
return response.json()
else:
raise NoClientInfoReceivedError("{} {}".format(response.status_code, response.reason))

def __setitem__(self, item, value):
"""Remote management is readonly."""
raise RuntimeError('MDQClient is readonly.')

def __delitem__(self, item):
""""Remote management is readonly."""
raise RuntimeError('MDQClient is readonly.')

def keys(self):
"""Get all registered entitites."""
mdx_url = urljoin(self.url, 'entities')
response = requests.get(mdx_url, headers=self.headers)
if response.status_code == 200:
return [item['client_id'] for item in response.json()]
else:
raise NoClientInfoReceivedError("{} {}".format(response.status_code, response.reason))

def items(self):
"""Geting all registered entities."""
mdx_url = urljoin(self.url, 'entities')
response = requests.get(mdx_url, headers=self.headers)
if response.status_code == 200:
return json.loads(response.text)
return response.json()
else:
raise NoClientInfoReceivedError("{} {}".format(response.status_code,
response.reason))
raise NoClientInfoReceivedError("{} {}".format(response.status_code, response.reason))


# Dictionary can be used as a ClientDatabase
BaseClientDatabase.register(dict)
142 changes: 119 additions & 23 deletions tests/test_clientdb.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,141 @@
# pylint: disable=missing-docstring

"""Unittests for ClientDatabases."""
import json
from operator import itemgetter

import httpretty
import pytest
import responses

from oic.oauth2.exception import NoClientInfoReceivedError
from oic.utils.clientdb import BaseClientDatabase
from oic.utils.clientdb import MDQClient


class TestBaseClientDatabase(object):

class DictClientDatabase(BaseClientDatabase):
"""Test implementation."""

def __init__(self):
self.db = {}

def __getitem__(self, key):
return self.db[key]

def __setitem__(self, key, value):
self.db[key] = value

def __delitem__(self, key):
del self.db[key]

def keys(self):
return self.db.keys()

def items(self):
return self.db.items()

def test_get_missing(self):
cdb = self.DictClientDatabase()
assert cdb.get('client') is None
assert cdb.get('client', 'spam') == 'spam'

def test_get(self):
cdb = self.DictClientDatabase()
cdb['client'] = 'value'

assert cdb.get('client', 'spam') == 'value'

def test_contains(self):
cdb = self.DictClientDatabase()
cdb['client1'] = 'spam'

assert 'client1' in cdb
assert 'client2' not in cdb

def test_len(self):
cdb = self.DictClientDatabase()
cdb['client1'] = 'spam'
cdb['client2'] = 'eggs'

assert len(cdb) == 2


class TestMDQClient(object):
URL = "http://localhost/mdx"
CLIENT_ID = "client1"
MDX_URL = URL + "/entities/" + CLIENT_ID
"""Tests for MDQClient."""

URL = "http://localhost/mdx/"

@pytest.fixture(autouse=True)
def create_client(self):
self.md = MDQClient(TestMDQClient.URL)

@httpretty.activate
def test_get_existing_client(self):
metadata = {"client_id": TestMDQClient.CLIENT_ID,
metadata = {"client_id": 'client1',
"client_secret": "abcd1234",
"redirect_uris": ["http://example.com/rp/authz_cb"]}
response_body = json.dumps(metadata)

httpretty.register_uri(httpretty.GET,
TestMDQClient.MDX_URL.format(
client_id=TestMDQClient.CLIENT_ID),
body=response_body,
content_type="application/json")
url = TestMDQClient.URL + 'entities/client1'
with responses.RequestsMock() as rsps:
rsps.add(rsps.GET, url, body=json.dumps(metadata))
result = self.md['client1']

result = self.md[TestMDQClient.CLIENT_ID]
assert metadata == result

@httpretty.activate
def test_get_non_existing_client(self):
httpretty.register_uri(httpretty.GET,
TestMDQClient.MDX_URL.format(
client_id=TestMDQClient.CLIENT_ID),
status=404)
url = TestMDQClient.URL + 'entities/client1'
with responses.RequestsMock() as rsps:
rsps.add(rsps.GET, url, status=404)
with pytest.raises(NoClientInfoReceivedError):
self.md['client1']

def test_keys(self):
url = TestMDQClient.URL + 'entities'
metadata = [
{'client_id': 'client1',
'client_secret': 'secret',
'redirect_uris': ['http://example.com']},
{'client_id': 'client2',
'client_secret': 'secret',
'redirect_uris': ['http://ecample2.com']},
]
with responses.RequestsMock() as rsps:
rsps.add(rsps.GET, url, body=json.dumps(metadata))
result = self.md.keys()

assert {'client1', 'client2'} == set(result)

def test_keys_error(self):
url = TestMDQClient.URL + 'entities'
with responses.RequestsMock() as rsps:
rsps.add(rsps.GET, url, status=404)
with pytest.raises(NoClientInfoReceivedError):
self.md.keys()

def test_items(self):
url = TestMDQClient.URL + 'entities'
metadata = [
{'client_id': 'client1',
'client_secret': 'secret',
'redirect_uris': ['http://example.com']},
{'client_id': 'client2',
'client_secret': 'secret',
'redirect_uris': ['http://ecample2.com']},
]
with responses.RequestsMock() as rsps:
rsps.add(rsps.GET, url, body=json.dumps(metadata))
result = self.md.items()

assert sorted(metadata, key=itemgetter('client_id')) == sorted(result, key=itemgetter('client_id'))

def test_items_errors(self):
url = TestMDQClient.URL + 'entities'
with responses.RequestsMock() as rsps:
rsps.add(rsps.GET, url, status=404)
with pytest.raises(NoClientInfoReceivedError):
self.md.items()

def test_setitem(self):
with pytest.raises(RuntimeError):
self.md['client'] = 'foo'

with pytest.raises(NoClientInfoReceivedError):
self.md[TestMDQClient.CLIENT_ID] # pylint: disable=pointless-statement
def test_delitem(self):
with pytest.raises(RuntimeError):
del self.md['client']
2 changes: 1 addition & 1 deletion tests/test_oic_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,7 @@ def test_provider_key_setup(self, tmpdir, session_db_factory):
# windows, so we throw them away and add a '.' for a local path.
path = "." + os.path.splitdrive(path)[1].replace(os.path.sep, '/')

provider = Provider("pyoicserv", session_db_factory(SERVER_INFO["issuer"]), None,
provider = Provider("pyoicserv", session_db_factory(SERVER_INFO["issuer"]), {},
None, None, None, None, None)
provider.baseurl = "http://www.example.com"
provider.key_setup(path, path, sig={"format": "jwk", "alg": "RSA"})
Expand Down

0 comments on commit 94c2276

Please sign in to comment.