Skip to content

Commit

Permalink
Wants the original non-parsed JWT and not an IDToken instance.
Browse files Browse the repository at this point in the history
  • Loading branch information
rohe authored and tpazderka committed Nov 4, 2019
1 parent ddfd8d1 commit f12a2b7
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 21 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ The format is based on the [KeepAChangeLog] project.

## Unreleased

### Fixed
- [#708] Wants the original non-parsed JWT and not an IDToken instance.

[#708]: https://github.com/OpenIDC/pyoidc/pull/708

## 1.1.0 [2019-10-25]

### Changed
Expand Down
3 changes: 2 additions & 1 deletion src/oic/oauth2/grant.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(self, resp=None):
self.refresh_token = None
self.token_type = None # type: Optional[str]
self.replaced = False
self.id_token = None

if resp:
for prop, val in resp.items():
Expand Down Expand Up @@ -68,7 +69,7 @@ def __init__(self, exp_in=600, resp=None, seed=""):
self.grant_expiration_time = 0
self.exp_in = exp_in
self.seed = seed
self.tokens = []
self.tokens = [] # type: List[Token]
self.id_token = None
self.code = None # type: Optional[str]
if resp:
Expand Down
1 change: 0 additions & 1 deletion src/oic/oauth2/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ def __init__(self, **kwargs):
self.jwe_header = None
self.from_dict(kwargs)
self.verify_ssl = True
self.raw_id_token = None

def __iter__(self):
return iter(self._dict)
Expand Down
8 changes: 4 additions & 4 deletions src/oic/oic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def _get_id_token(self, **kwargs):
if not flag:
break
if token.id_token:
return token.id_token
return token.id_token.jwt

return None

Expand Down Expand Up @@ -591,11 +591,11 @@ def _id_token_based(self, request, request_args=None, extra_args=None, **kwargs)
if _prop in request_args:
pass
else:
id_token = self._get_id_token(**kwargs)
if id_token is None:
raw_id_token = self._get_id_token(**kwargs)
if raw_id_token is None:
raise MissingParameter("No valid id token available")

request_args[_prop] = id_token
request_args[_prop] = raw_id_token

return self.construct_request(request, request_args, extra_args)

Expand Down
5 changes: 0 additions & 5 deletions src/oic/oic/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,6 @@ class AccessTokenResponse(message.AccessTokenResponse):
def verify(self, **kwargs):
super().verify(**kwargs)
if "id_token" in self:
self.raw_id_token = self["id_token"]
# replace the JWT with the verified IdToken instance
self["id_token"] = verify_id_token(self, **kwargs)

Expand Down Expand Up @@ -381,7 +380,6 @@ def verify(self, **kwargs):
return False

if "id_token" in self:
self.raw_id_token = self["id_token"]
self["id_token"] = verify_id_token(self, check_hash=True, **kwargs)

if "access_token" in self:
Expand Down Expand Up @@ -797,7 +795,6 @@ class RefreshSessionRequest(StateFullMessage):
def verify(self, **kwargs):
super(RefreshSessionRequest, self).verify(**kwargs)
if "id_token" in self:
self.raw_id_token = self["id_token"]
self["id_token"] = verify_id_token(self, check_hash=True, **kwargs)


Expand All @@ -808,7 +805,6 @@ class RefreshSessionResponse(StateFullMessage):
def verify(self, **kwargs):
super(RefreshSessionResponse, self).verify(**kwargs)
if "id_token" in self:
self.raw_id_token = self["id_token"]
self["id_token"] = verify_id_token(self, check_hash=True, **kwargs)


Expand All @@ -818,7 +814,6 @@ class CheckSessionRequest(Message):
def verify(self, **kwargs):
super(CheckSessionRequest, self).verify(**kwargs)
if "id_token" in self:
self.raw_id_token = self["id_token"]
self["id_token"] = verify_id_token(self, check_hash=True, **kwargs)


Expand Down
1 change: 1 addition & 0 deletions tests/test_grant.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def test_access_token(self):
"refresh_token",
"scope",
"replaced",
"id_token",
],
)

Expand Down
60 changes: 50 additions & 10 deletions tests/test_oic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from oic.exception import RegistrationError
from oic.oauth2.exception import OtherError
from oic.oauth2.message import SINGLE_OPTIONAL_STRING
from oic.oauth2.message import MessageTuple
from oic.oic import DEF_SIGN_ALG
from oic.oic import Client
Expand All @@ -22,7 +23,6 @@
from oic.oic import Token
from oic.oic import scope2claims
from oic.oic.message import SCOPE2CLAIMS
from oic.oic.message import SINGLE_OPTIONAL_STRING
from oic.oic.message import AccessTokenRequest
from oic.oic.message import AccessTokenResponse
from oic.oic.message import AuthorizationRequest
Expand Down Expand Up @@ -57,8 +57,8 @@
)

BASE_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "data/keys"))
_key = rsa_load(os.path.join(BASE_PATH, "rsa.key"))
KC_RSA = KeyBundle({"key": _key, "kty": "RSA", "use": "sig"})
RSA_KEY = rsa_load(os.path.join(BASE_PATH, "rsa.key"))
KC_RSA = KeyBundle({"key": RSA_KEY, "kty": "RSA", "use": "sig"})

KEYJ = KeyJar()
KEYJ[""] = [KC_RSA, KC_SYM_S]
Expand Down Expand Up @@ -182,7 +182,6 @@ def test_access_token_request(self):
assert _eq(resp.keys(), ["token_type", "state", "access_token", "scope"])

def test_access_token_request_with_custom_response_class(self):

# AccessTokenResponse wrapper class
class AccessTokenResponseWrapper(AccessTokenResponse):
c_param = AccessTokenResponse.c_param.copy()
Expand Down Expand Up @@ -595,14 +594,24 @@ def test_construct_CheckSessionRequest_2(self):
self.client.grant["foo"].grant_expiration_time = int(time.time() + 60)
self.client.grant["foo"].code = "access_code"

# Need a proper ID Token
self.client.keyjar.add_kb(IDTOKEN["iss"], KC_SYM_S)
_sig_key = self.client.keyjar.get_signing_key("oct", IDTOKEN["iss"])
_signed_jwt = IDTOKEN.to_jwt(_sig_key, algorithm="HS256")

resp = AccessTokenResponse(
id_token="id_id_id_id", access_token="access", scope=["openid"]
id_token=_signed_jwt,
access_token="access",
scope=["openid"],
token_type="bearer",
)

assert resp.verify(keyjar=self.client.keyjar)

self.client.grant["foo"].tokens.append(Token(resp))

csr = self.client.construct_CheckSessionRequest(state="foo", scope=["openid"])
assert csr["id_token"] == "id_id_id_id"
assert csr["id_token"] == _signed_jwt

def test_construct_RegistrationRequest(self):
request_args = {
Expand Down Expand Up @@ -632,10 +641,20 @@ def test_construct_EndSessionRequest_kwargs_state(self):
self.client.grant["foo"].grant_expiration_time = int(time.time() + 60)
self.client.grant["foo"].code = "access_code"

# Need a proper ID Token
self.client.keyjar.add_kb(IDTOKEN["iss"], KC_SYM_S)
_sig_key = self.client.keyjar.get_signing_key("oct", IDTOKEN["iss"])
_signed_jwt = IDTOKEN.to_jwt(_sig_key, algorithm="HS256")

resp = AccessTokenResponse(
id_token="id_id_id_id", access_token="access", scope=["openid"]
id_token=_signed_jwt,
access_token="access",
scope=["openid"],
token_type="bearer",
)

assert resp.verify(keyjar=self.client.keyjar)

self.client.grant["foo"].tokens.append(Token(resp))

# state only in kwargs
Expand All @@ -648,10 +667,21 @@ def test_construct_EndSessionRequest_reqargs_state(self):
self.client.grant["foo"].grant_expiration_time = int(time.time()) + 60
self.client.grant["foo"].code = "access_code"

# Need a proper ID Token
self.client.keyjar.add_kb(IDTOKEN["iss"], KC_SYM_S)
_sig_key = self.client.keyjar.get_signing_key("oct", IDTOKEN["iss"])
_signed_jwt = IDTOKEN.to_jwt(_sig_key, algorithm="HS256")

resp = AccessTokenResponse(
id_token="id_id_id_id", access_token="access", scope=["openid"]
id_token=_signed_jwt,
access_token="access",
scope=["openid"],
token_type="bearer",
)

# Need to do this to get things in place
assert resp.verify(keyjar=self.client.keyjar)

self.client.grant["foo"].tokens.append(Token(resp))

# state only in request_args
Expand All @@ -664,10 +694,20 @@ def test_construct_EndSessionRequest_kwargs_and_reqargs_state(self):
self.client.grant["foo"].grant_expiration_time = int(time.time()) + 60
self.client.grant["foo"].code = "access_code"

# Need a proper ID Token
self.client.keyjar.add_kb(IDTOKEN["iss"], KC_SYM_S)
_sig_key = self.client.keyjar.get_signing_key("oct", IDTOKEN["iss"])
_signed_jwt = IDTOKEN.to_jwt(_sig_key, algorithm="HS256")

resp = AccessTokenResponse(
id_token="id_id_id_id", access_token="access", scope=["openid"]
id_token=_signed_jwt,
access_token="access",
scope=["openid"],
token_type="bearer",
)

assert resp.verify(keyjar=self.client.keyjar)

self.client.grant["foo"].tokens.append(Token(resp))

# state both in request_args and kwargs
Expand Down Expand Up @@ -739,7 +779,7 @@ def test_userinfo_request_post(self):
}

def test_sign_enc_request(self):
KC_RSA_ENC = KeyBundle({"key": _key, "kty": "RSA", "use": "enc"})
KC_RSA_ENC = KeyBundle({"key": RSA_KEY, "kty": "RSA", "use": "enc"})
self.client.keyjar["test_provider"] = [KC_RSA_ENC]

request_args = {
Expand Down

0 comments on commit f12a2b7

Please sign in to comment.