diff --git a/axum-login/src/auth.rs b/axum-login/src/auth.rs index 5eb1daf..1c54d11 100644 --- a/axum-login/src/auth.rs +++ b/axum-login/src/auth.rs @@ -160,7 +160,7 @@ where if let Some(role) = role { RangeBounds::contains(self, &role) } else { - role.is_none() + role.is_some() } } } @@ -171,7 +171,7 @@ where pub struct Login { login_url: Option>>, redirect_field_name: Option>>, - role_bounds: Box>, + role_bounds: Option>>, _user_id_type: PhantomData, _user_type: PhantomData, _body_type: PhantomData ResBody>, @@ -182,7 +182,10 @@ impl Clone for Login { Self { login_url: self.login_url.clone(), redirect_field_name: self.redirect_field_name.clone(), - role_bounds: dyn_clone::clone_box(&*self.role_bounds), + role_bounds: self + .role_bounds + .as_ref() + .map(|rb| dyn_clone::clone_box(&**rb)), _user_id_type: PhantomData, _user_type: PhantomData, _body_type: PhantomData, @@ -205,7 +208,10 @@ where BoxFuture<'static, Result, Response>>; fn authorize(&mut self, mut request: Request) -> Self::Future { - let role_bounds = dyn_clone::clone_box(&*self.role_bounds); + let role_bounds = self + .role_bounds + .as_ref() + .map(|rb| dyn_clone::clone_box(&**rb)); let login_url = self.login_url.clone(); let redirect_field_name = self.redirect_field_name.clone(); Box::pin(async move { @@ -215,7 +221,11 @@ where .expect("Auth extension missing. Is the auth layer installed?"); match user { - Some(user) if role_bounds.contains(user.get_role()) => { + Some(user) + if role_bounds + .map(|rb| rb.contains(user.get_role())) + .unwrap_or(true) => + { let user = user.clone(); request.extensions_mut().insert(user); @@ -270,7 +280,7 @@ where tower_http::auth::AsyncRequireAuthorizationLayer::new(Login::<_, _, _, _> { login_url: None, redirect_field_name: None, - role_bounds: Box::new(..), + role_bounds: None, _user_id_type: PhantomData, _user_type: PhantomData, _body_type: PhantomData, @@ -289,7 +299,7 @@ where tower_http::auth::AsyncRequireAuthorizationLayer::new(Login::<_, _, _, _> { login_url: None, redirect_field_name: None, - role_bounds: Box::new(role_bounds), + role_bounds: Some(Box::new(role_bounds)), _user_id_type: PhantomData, _user_type: PhantomData, _body_type: PhantomData, @@ -314,7 +324,7 @@ where tower_http::auth::AsyncRequireAuthorizationLayer::new(Login::<_, _, _, _> { login_url: Some(login_url), redirect_field_name, - role_bounds: Box::new(..), + role_bounds: None, _user_id_type: PhantomData, _user_type: PhantomData, _body_type: PhantomData, @@ -341,7 +351,7 @@ where tower_http::auth::AsyncRequireAuthorizationLayer::new(Login::<_, _, _, _> { login_url: Some(login_url), redirect_field_name, - role_bounds: Box::new(role_bounds), + role_bounds: Some(Box::new(role_bounds)), _user_id_type: PhantomData, _user_type: PhantomData, _body_type: PhantomData, @@ -370,9 +380,16 @@ mod tests { AuthLayer, AuthUser, }; + #[derive(Debug, Clone, PartialEq, PartialOrd)] + enum Role { + User, + Admin, + } + #[derive(Debug, Default, Clone)] struct User { id: usize, + role: Option, password_hash: String, } @@ -385,17 +402,22 @@ mod tests { } } - impl AuthUser for User { + impl AuthUser for User { fn get_id(&self) -> usize { self.id } + fn get_role(&self) -> Option { + self.role.clone() + } + fn get_password_hash(&self) -> secrecy::SecretVec { secrecy::SecretVec::new(self.password_hash.clone().into()) } } - type Auth = AuthContext>; + type Auth = AuthContext, Role>; + type RequireAuth = crate::auth::RequireAuthorizationLayer; #[tokio::test] async fn logs_user_in() { @@ -464,7 +486,6 @@ mod tests { Ok(Response::new(req.into_body())) } - type RequireAuth = crate::auth::RequireAuthorizationLayer; #[tokio::test] async fn redirects_to_login_url() { @@ -626,4 +647,188 @@ mod tests { .unwrap(); assert_eq!(res.status(), StatusCode::OK); } + + #[tokio::test] + async fn login_with_role_or_redirect() { + let secret = rand::thread_rng().gen::<[u8; 64]>(); + + let store = MemoryStore::new(); + let session_layer = SessionLayer::new(store, &secret); + + let store = Arc::new(RwLock::new(HashMap::default())); + let user = User::get_rusty_user(); + store.write().await.insert(user.get_id(), user); + + let user_store = AuthMemoryStore::new(&store); + let auth_layer = AuthLayer::new(user_store, &secret); + + let login_url = Arc::new("/login".into()); + + let mut service = ServiceBuilder::new() + .layer(session_layer.clone()) + .layer(auth_layer.clone()) + .service_fn(login); + + let mut protected_service = ServiceBuilder::new() + .layer(session_layer) + .layer(auth_layer) + .layer(RequireAuth::login_with_role_or_redirect( + Role::Admin.., + Arc::clone(&login_url), + None, + )) + .service_fn(login); + + let request = Request::get("/protected").body(Body::empty()).unwrap(); + let res = protected_service + .ready() + .await + .unwrap() + .call(request) + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::TEMPORARY_REDIRECT); + assert_eq!( + res.headers().get(http::header::LOCATION), + Some(&login_url.as_ref().as_ref().try_into().unwrap()) + ); + let session_cookie = res.headers().get(SET_COOKIE).unwrap().clone(); + + let mut request = Request::get("/login").body(Body::empty()).unwrap(); + request + .headers_mut() + .insert(COOKIE, session_cookie.to_owned()); + let res = service.ready().await.unwrap().call(request).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let mut request = Request::get("/protected").body(Body::empty()).unwrap(); + request + .headers_mut() + .insert(COOKIE, session_cookie.to_owned()); + let res = protected_service + .ready() + .await + .unwrap() + .call(request) + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::TEMPORARY_REDIRECT); + assert_eq!( + res.headers().get(http::header::LOCATION), + Some(&login_url.as_ref().as_ref().try_into().unwrap()) + ); + for (role, status) in [ + (Role::User, StatusCode::TEMPORARY_REDIRECT), + (Role::Admin, StatusCode::OK), + ] { + let mut user = User::get_rusty_user(); + user.role = Some(role); + store.write().await.insert(user.get_id(), user); + + let mut request = Request::get("/login").body(Body::empty()).unwrap(); + request + .headers_mut() + .insert(COOKIE, session_cookie.to_owned()); + let res = service.ready().await.unwrap().call(request).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let mut request = Request::get("/protected").body(Body::empty()).unwrap(); + request + .headers_mut() + .insert(COOKIE, session_cookie.to_owned()); + let res = protected_service + .ready() + .await + .unwrap() + .call(request) + .await + .unwrap(); + assert_eq!(res.status(), status); + } + } + + #[tokio::test] + async fn login_with_role() { + let secret = rand::thread_rng().gen::<[u8; 64]>(); + + let store = MemoryStore::new(); + let session_layer = SessionLayer::new(store, &secret); + + let store = Arc::new(RwLock::new(HashMap::default())); + let user = User::get_rusty_user(); + store.write().await.insert(user.get_id(), user); + + let user_store = AuthMemoryStore::new(&store); + let auth_layer = AuthLayer::new(user_store, &secret); + + let mut service = ServiceBuilder::new() + .layer(session_layer.clone()) + .layer(auth_layer.clone()) + .service_fn(login); + + let mut protected_service = ServiceBuilder::new() + .layer(session_layer) + .layer(auth_layer) + .layer(RequireAuth::login_with_role(Role::Admin..)) + .service_fn(login); + + let request = Request::get("/protected").body(Body::empty()).unwrap(); + let res = protected_service + .ready() + .await + .unwrap() + .call(request) + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); + let session_cookie = res.headers().get(SET_COOKIE).unwrap().clone(); + + let mut request = Request::get("/login").body(Body::empty()).unwrap(); + request + .headers_mut() + .insert(COOKIE, session_cookie.to_owned()); + let res = service.ready().await.unwrap().call(request).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let mut request = Request::get("/protected").body(Body::empty()).unwrap(); + request + .headers_mut() + .insert(COOKIE, session_cookie.to_owned()); + let res = protected_service + .ready() + .await + .unwrap() + .call(request) + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); + for (role, status) in [ + (Role::User, StatusCode::UNAUTHORIZED), + (Role::Admin, StatusCode::OK), + ] { + let mut user = User::get_rusty_user(); + user.role = Some(role); + store.write().await.insert(user.get_id(), user); + + let mut request = Request::get("/login").body(Body::empty()).unwrap(); + request + .headers_mut() + .insert(COOKIE, session_cookie.to_owned()); + let res = service.ready().await.unwrap().call(request).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let mut request = Request::get("/protected").body(Body::empty()).unwrap(); + request + .headers_mut() + .insert(COOKIE, session_cookie.to_owned()); + let res = protected_service + .ready() + .await + .unwrap() + .call(request) + .await + .unwrap(); + assert_eq!(res.status(), status); + } + } }