From 94d1f99ee3d76b981fd173e1ab810dc548e9e24c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Pazderka?= Date: Fri, 10 Mar 2023 10:19:00 +0100 Subject: [PATCH] Pass timeout to all requests --- oidc_example/rp2/oidc.py | 3 ++- src/oic/utils/authn/user_cas.py | 12 ++++++++++-- src/oic/utils/clientdb.py | 11 ++++++----- src/oic/utils/keyio.py | 4 ++-- 4 files changed, 20 insertions(+), 10 deletions(-) diff --git a/oidc_example/rp2/oidc.py b/oidc_example/rp2/oidc.py index 6e92177c8..0526cd278 100644 --- a/oidc_example/rp2/oidc.py +++ b/oidc_example/rp2/oidc.py @@ -126,7 +126,8 @@ def begin(self, environ, server_env, start_response, session, key): if client is not None and self.srv_discovery_url: data = {"client_id": client.client_id} resp = requests.get(self.srv_discovery_url + "verifyClientId", - params=data, verify=self.extra["ca_bundle"]) + params=data, verify=self.extra["ca_bundle"], + timeout=10) if not resp.ok and resp.status_code == 400: client = None server_env["OIC_CLIENT"].pop(key, None) diff --git a/src/oic/utils/authn/user_cas.py b/src/oic/utils/authn/user_cas.py index 7f8d1946f..4934c431b 100644 --- a/src/oic/utils/authn/user_cas.py +++ b/src/oic/utils/authn/user_cas.py @@ -41,7 +41,9 @@ class CasAuthnMethod(UserAuthnMethod): # The name for the CAS cookie, containing query parameters and nonce. CONST_CAS_COOKIE = "cascookie" - def __init__(self, srv, cas_server, service_url, return_to, extra_validation=None): + def __init__( + self, srv, cas_server, service_url, return_to, extra_validation=None, timeout=5 + ): """ Construct the class. @@ -51,12 +53,14 @@ def __init__(self, srv, cas_server, service_url, return_to, extra_validation=Non this case the oic server's verify URL. :param return_to: The URL to return to after a successful authentication. + :param timeout: Timeout for requests library. """ UserAuthnMethod.__init__(self, srv) self.cas_server = cas_server self.service_url = service_url self.return_to = return_to self.extra_validation = extra_validation + self.timeout = timeout def create_redirect(self, query): """ @@ -101,7 +105,11 @@ def handle_callback(self, ticket, service_url): :return: Uid if the login was successful otherwise None. """ data = {self.CONST_TICKET: ticket, self.CONST_SERVICE: service_url} - resp = requests.get(self.cas_server + self.CONST_CAS_VERIFY_TICKET, params=data) + resp = requests.get( + self.cas_server + self.CONST_CAS_VERIFY_TICKET, + params=data, + timeout=self.timeout, + ) root = ET.fromstring(resp.content) for l1 in root: if self.CONST_AUTHSUCCESS in l1.tag: diff --git a/src/oic/utils/clientdb.py b/src/oic/utils/clientdb.py index 634474415..24e20c003 100644 --- a/src/oic/utils/clientdb.py +++ b/src/oic/utils/clientdb.py @@ -74,15 +74,16 @@ def __len__(self): class MDQClient(BaseClientDatabase): """Implementation of remote client database.""" - def __init__(self, url): - """Set the remote storage url.""" + def __init__(self, url, timeout=5): + """Set the remote storage url and timeout for requests.""" self.url = url + self.timeout = timeout self.headers = {"Accept": "application/json", "Accept-Encoding": "gzip"} def __getitem__(self, item): """Retrieve a single entity.""" mdx_url = urljoin(self.url, "entities/{}".format(quote(item, safe=""))) - response = requests.get(mdx_url, headers=self.headers) + response = requests.get(mdx_url, headers=self.headers, timeout=self.timeout) if response.status_code == 200: return response.json() else: @@ -101,7 +102,7 @@ def __delitem__(self, item): def keys(self): """Get all registered entitites.""" mdx_url = urljoin(self.url, "entities") - response = requests.get(mdx_url, headers=self.headers) + response = requests.get(mdx_url, headers=self.headers, timeout=self.timeout) if response.status_code == 200: return [item["client_id"] for item in response.json()] else: @@ -112,7 +113,7 @@ def keys(self): def items(self): """Geting all registered entities.""" mdx_url = urljoin(self.url, "entities") - response = requests.get(mdx_url, headers=self.headers) + response = requests.get(mdx_url, headers=self.headers, timeout=self.timeout) if response.status_code == 200: return response.json() else: diff --git a/src/oic/utils/keyio.py b/src/oic/utils/keyio.py index f91c8b196..fd6e139a3 100644 --- a/src/oic/utils/keyio.py +++ b/src/oic/utils/keyio.py @@ -185,13 +185,13 @@ def do_remote(self): if self.source is None: # Nothing to do return False - args = {"verify": self.verify_ssl, "timeout": self.timeout} + args = {"verify": self.verify_ssl} if self.etag: args["headers"] = {"If-None-Match": self.etag} try: logger.debug("KeyBundle fetch keys from: %s", self.source) - r = requests.get(self.source, **args) + r = requests.get(self.source, timeout=self.timeout, **args) except Exception as err: logger.error(err) raise_exception(UpdateFailed, REMOTE_FAILED.format(self.source, str(err)))