Skip to content

Commit

Permalink
Merge pull request #456 from sebadob/simplify-post-authorize-code
Browse files Browse the repository at this point in the history
Simplify post authorize code
  • Loading branch information
sebadob committed Jun 7, 2024
2 parents 9f85c77 + e79f362 commit af0db9d
Show file tree
Hide file tree
Showing 12 changed files with 227 additions and 222 deletions.
10 changes: 10 additions & 0 deletions rauthy-common/src/error_response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -449,3 +449,13 @@ impl From<header::InvalidHeaderValue> for ErrorResponse {
)
}
}

impl From<std::fmt::Error> for ErrorResponse {
fn from(value: std::fmt::Error) -> Self {
trace!("{:?}", value);
ErrorResponse::new(
ErrorResponseType::Internal,
format!("fmt error: {:?}", value),
)
}
}
8 changes: 3 additions & 5 deletions rauthy-handlers/src/auth_providers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,7 @@ pub async fn post_provider_callback(
let (auth_step, cookie) =
AuthProviderCallback::login_finish(&data, &req, &payload, session.clone()).await?;

let (mut resp, _) = map_auth_step(auth_step, &req)
.await
.map_err(|(err, _)| err)?;
let mut resp = map_auth_step(auth_step, &req).await?;
resp.add_cookie(&cookie).map_err(|err| {
ErrorResponse::new(
ErrorResponseType::Internal,
Expand Down Expand Up @@ -411,7 +409,7 @@ pub async fn put_provider_img(
None => {
return Err(ErrorResponse::new(
ErrorResponseType::BadRequest,
"content_type is missing".to_string(),
"content_type is missing",
));
}
}
Expand Down Expand Up @@ -466,7 +464,7 @@ pub async fn post_provider_link(
if user.auth_provider_id.is_some() {
return Err(ErrorResponse::new(
ErrorResponseType::BadRequest,
"user is already federated".to_string(),
"user is already federated",
));
}

Expand Down
19 changes: 8 additions & 11 deletions rauthy-handlers/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pub async fn map_auth_step(
req: &HttpRequest,
// the bool for Ok() is true is the password has been hashed
// the bool for Err() means if we need to add a login delay (and none otherwise for better UX)
) -> Result<(HttpResponse, bool), (ErrorResponse, bool)> {
) -> Result<HttpResponse, ErrorResponse> {
// we will only get here after a successful login -> always return logged-in header
let fed_cm_header = FedCMLoginStatus::LoggedIn.as_header_pair();

Expand All @@ -59,7 +59,7 @@ pub async fn map_auth_step(
if let Some((name, value)) = res.header_origin {
resp.headers_mut().insert(name, value);
}
Ok((resp, res.has_password_been_hashed))
Ok(resp)
}

AuthStep::AwaitWebauthn(res) => {
Expand All @@ -82,23 +82,20 @@ pub async fn map_auth_step(
WebauthnCookie::parse_validate(&ApiCookie::from_req(req, COOKIE_MFA))
{
if mfa_cookie.email != res.email {
add_req_mfa_cookie(&mut resp, res.email.clone()).map_err(|err| (err, true))?;
add_req_mfa_cookie(&mut resp, res.email.clone())?;
}
} else {
add_req_mfa_cookie(&mut resp, res.email.clone()).map_err(|err| (err, true))?;
add_req_mfa_cookie(&mut resp, res.email.clone())?;
}

Ok((resp, res.has_password_been_hashed))
Ok(resp)
}

AuthStep::ProviderLink => {
// TODO generate a new event type in this case?
Ok((
HttpResponse::NoContent()
.insert_header(fed_cm_header)
.finish(),
false,
))
Ok(HttpResponse::NoContent()
.insert_header(fed_cm_header)
.finish())
}
}
}
Expand Down
76 changes: 64 additions & 12 deletions rauthy-handlers/src/oidc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,13 +256,55 @@ pub async fn post_authorize(
let start = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();

let session = principal.get_session()?;
let res = match auth::authorize(&data, &req, payload.into_inner(), session.clone()).await {

let mut has_password_been_hashed = false;
let mut add_login_delay = true;
let mut user_needs_mfa = false;

let res = match auth::authorize(
&data,
&req,
payload.into_inner(),
session.clone(),
&mut has_password_been_hashed,
&mut add_login_delay,
&mut user_needs_mfa,
)
.await
{
Ok(auth_step) => map_auth_step(auth_step, &req).await,
Err(err) => Err(err),
Err(err) => {
debug!("{:?}", err);
// We always must return the exact same error type, no matter what the actual error is,
// to prevent information enumeration. The only exception is when the user needs to add
// a passkey to the account while having given the correct credentials. In that case,
// we return the original error to be able to display the info message in the UI.
if user_needs_mfa {
// in this case, we can return directly without any login delay
return Err(err);
}

let err = Err(ErrorResponse::new(
ErrorResponseType::Unauthorized,
"Invalid user credentials",
));
if !add_login_delay {
return err;
}
err
}
};

let ip = real_ip_from_req(&req);
auth::handle_login_delay(&data, ip, start, &data.caches.ha_cache_config, res).await
auth::handle_login_delay(
&data,
ip,
start,
&data.caches.ha_cache_config,
res,
has_password_been_hashed,
)
.await
}

/// Immediate login refresh with valid session
Expand Down Expand Up @@ -303,10 +345,7 @@ pub async fn post_authorize_refresh(
let auth_step =
auth::authorize_refresh(&data, session, client, header_origin, req_data.into_inner())
.await?;
map_auth_step(auth_step, &req)
.await
.map(|res| res.0)
.map_err(|err| err.0)
map_auth_step(auth_step, &req).await
}

#[get("/oidc/callback")]
Expand Down Expand Up @@ -845,14 +884,14 @@ pub async fn post_token(
let ip = real_ip_from_req(&req);

if payload.grant_type == GRANT_TYPE_DEVICE_CODE {
// TODO the `urn:ietf:params:oauth:grant-type:device_code` needs
// the `urn:ietf:params:oauth:grant-type:device_code` needs
// a fully customized handling here with customized error response
// to meet the oauth rfc
return Ok(auth::grant_type_device_code(&data, ip, payload.into_inner()).await);
}

let start = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
let add_login_delay = payload.grant_type == "password";
let has_password_been_hashed = payload.grant_type == "password";

let res = match auth::get_token_set(payload.into_inner(), &data, req).await {
Ok((token_set, headers)) => {
Expand All @@ -861,12 +900,25 @@ pub async fn post_token(
builder.insert_header(h);
}
let resp = builder.json(token_set);
Ok((resp, add_login_delay))
Ok(resp)
}
Err(err) => {
if !has_password_been_hashed {
return Err(err);
}
Err(err)
}
Err(err) => Err((err, add_login_delay)),
};

auth::handle_login_delay(&data, ip, start, &data.caches.ha_cache_config, res).await
auth::handle_login_delay(
&data,
ip,
start,
&data.caches.ha_cache_config,
res,
has_password_been_hashed,
)
.await
}

/// The tokenInfo endpoint for the OIDC standard.
Expand Down
3 changes: 1 addition & 2 deletions rauthy-models/src/entity/app_version.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,7 @@ impl LatestAppVersion {
} else {
Err(ErrorResponse::new(
ErrorResponseType::Internal,
"Could not find 'tag_name' in Rauthy App Version lookup response"
.to_string(),
"Could not find 'tag_name' in Rauthy App Version lookup response",
))
}?;

Expand Down
19 changes: 7 additions & 12 deletions rauthy-models/src/entity/auth_providers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ impl AuthProviderCallback {
let callback_id = ApiCookie::from_req(req, COOKIE_UPSTREAM_CALLBACK).ok_or_else(|| {
ErrorResponse::new(
ErrorResponseType::Forbidden,
"Missing encrypted callback cookie".to_string(),
"Missing encrypted callback cookie",
)
})?;

Expand All @@ -712,7 +712,7 @@ impl AuthProviderCallback {
error!("`state` does not match");
return Err(ErrorResponse::new(
ErrorResponseType::BadRequest,
"`state` does not match".to_string(),
"`state` does not match",
));
}
debug!("callback state is valid");
Expand All @@ -725,7 +725,7 @@ impl AuthProviderCallback {
error!("invalid CSRF token");
return Err(ErrorResponse::new(
ErrorResponseType::Unauthorized,
"invalid CSRF token".to_string(),
"invalid CSRF token",
));
}
debug!("callback csrf token is valid");
Expand All @@ -739,7 +739,7 @@ impl AuthProviderCallback {
error!("invalid PKCE verifier");
return Err(ErrorResponse::new(
ErrorResponseType::Unauthorized,
"invalid PKCE verifier".to_string(),
"invalid PKCE verifier",
));
}
debug!("callback pkce verifier is valid");
Expand Down Expand Up @@ -835,10 +835,7 @@ impl AuthProviderCallback {
} else {
let err = "Neither `access_token` nor `id_token` existed";
error!("{}", err);
return Err(ErrorResponse::new(
ErrorResponseType::BadRequest,
err.to_string(),
));
return Err(ErrorResponse::new(ErrorResponseType::BadRequest, err));
}
}
Err(err) => {
Expand Down Expand Up @@ -871,7 +868,7 @@ impl AuthProviderCallback {
if provider_mfa_login == ProviderMfaLogin::No && !user.has_webauthn_enabled() {
return Err(ErrorResponse::new(
ErrorResponseType::MfaRequired,
"MFA is required for this client".to_string(),
"MFA is required for this client",
));
}
session.set_mfa(data, true).await?;
Expand Down Expand Up @@ -905,12 +902,11 @@ impl AuthProviderCallback {
// location header
let mut loc = format!("{}?code={}", slf.req_redirect_uri, code.id);
if let Some(state) = slf.req_state {
write!(loc, "&state={}", state).expect("`write!` to succeed");
write!(loc, "&state={}", state)?;
};

let auth_step = if user.has_webauthn_enabled() {
let step = AuthStepAwaitWebauthn {
has_password_been_hashed: false,
code: get_rand(48),
header_csrf: Session::get_csrf_header(&session.csrf_token),
header_origin,
Expand All @@ -935,7 +931,6 @@ impl AuthProviderCallback {
AuthStep::AwaitWebauthn(step)
} else {
AuthStep::LoggedIn(AuthStepLoggedIn {
has_password_been_hashed: false,
user_id: user.id,
email: user.email,
header_loc: (header::LOCATION, HeaderValue::from_str(&loc).unwrap()),
Expand Down
4 changes: 2 additions & 2 deletions rauthy-models/src/entity/clients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ impl Client {
trace!("'code_challenge_method' is missing");
return Err(ErrorResponse::new(
ErrorResponseType::BadRequest,
String::from("'code_challenge_method' is missing"),
"'code_challenge_method' is missing",
));
}

Expand All @@ -781,7 +781,7 @@ impl Client {
trace!("'code_challenge' not enabled for this client");
Err(ErrorResponse::new(
ErrorResponseType::BadRequest,
"'code_challenge' not enabled for this client".to_string(),
"'code_challenge' not enabled for this client",
))
} else {
Ok(())
Expand Down
12 changes: 6 additions & 6 deletions rauthy-models/src/entity/principal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ impl Principal {
if *ADMIN_FORCE_MFA && !self.has_mfa_active() {
return Err(ErrorResponse::new(
ErrorResponseType::MfaRequired,
"Rauthy admin access only allowed with MFA active".to_string(),
"Rauthy admin access only allowed with MFA active",
));
}

Expand Down Expand Up @@ -184,14 +184,14 @@ impl Principal {
trace!("Validating the session failed - was not in auth state");
Err(ErrorResponse::new(
ErrorResponseType::Unauthorized,
"Unauthorized Session".to_string(),
"Unauthorized Session",
))
}
} else {
trace!("Validating the session failed - no session found");
Err(ErrorResponse::new(
ErrorResponseType::Unauthorized,
"No valid session".to_string(),
"No valid session",
))
}
}
Expand All @@ -205,14 +205,14 @@ impl Principal {
trace!("Validating the session failed - was not in init or auth state");
Err(ErrorResponse::new(
ErrorResponseType::Unauthorized,
"Unauthorized Session".to_string(),
"Unauthorized Session",
))
}
} else {
trace!("Validating the session failed - no session found");
Err(ErrorResponse::new(
ErrorResponseType::Unauthorized,
"No valid session".to_string(),
"No valid session",
))
}
}
Expand All @@ -230,7 +230,7 @@ impl Principal {
trace!("Validating the session failed - was not in init state");
Err(ErrorResponse::new(
ErrorResponseType::Unauthorized,
"Session in Init state mandatory".to_string(),
"Session in Init state mandatory",
))
}
}
Expand Down
2 changes: 1 addition & 1 deletion rauthy-models/src/entity/sessions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ impl Session {
if OffsetDateTime::now_utc().unix_timestamp() > ts {
return Err(ErrorResponse::new(
ErrorResponseType::Forbidden,
"User has expired".to_string(),
"User has expired",
));
} else if ts < self.exp {
self.exp = ts;
Expand Down
Loading

0 comments on commit af0db9d

Please sign in to comment.