Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make role_bounds optional to require just login #67

Merged
merged 1 commit into from
May 10, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 217 additions & 12 deletions axum-login/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ where
if let Some(role) = role {
RangeBounds::contains(self, &role)
} else {
role.is_none()
role.is_some()
}
}
}
Expand All @@ -171,7 +171,7 @@ where
pub struct Login<UserId, User, ResBody, Role = ()> {
login_url: Option<Arc<Cow<'static, str>>>,
redirect_field_name: Option<Arc<Cow<'static, str>>>,
role_bounds: Box<dyn RoleBounds<Role>>,
role_bounds: Option<Box<dyn RoleBounds<Role>>>,
_user_id_type: PhantomData<UserId>,
_user_type: PhantomData<User>,
_body_type: PhantomData<fn() -> ResBody>,
Expand All @@ -182,7 +182,10 @@ impl<UserId, User, ResBody, Role> Clone for Login<UserId, User, ResBody, Role> {
Self {
login_url: self.login_url.clone(),
redirect_field_name: self.redirect_field_name.clone(),
role_bounds: dyn_clone::clone_box(&*self.role_bounds),
role_bounds: self
.role_bounds
.as_ref()
.map(|rb| dyn_clone::clone_box(&**rb)),
_user_id_type: PhantomData,
_user_type: PhantomData,
_body_type: PhantomData,
Expand All @@ -205,7 +208,10 @@ where
BoxFuture<'static, Result<Request<Self::RequestBody>, Response<Self::ResponseBody>>>;

fn authorize(&mut self, mut request: Request<ReqBody>) -> Self::Future {
let role_bounds = dyn_clone::clone_box(&*self.role_bounds);
let role_bounds = self
.role_bounds
.as_ref()
.map(|rb| dyn_clone::clone_box(&**rb));
let login_url = self.login_url.clone();
let redirect_field_name = self.redirect_field_name.clone();
Box::pin(async move {
Expand All @@ -215,7 +221,11 @@ where
.expect("Auth extension missing. Is the auth layer installed?");

match user {
Some(user) if role_bounds.contains(user.get_role()) => {
Some(user)
if role_bounds
.map(|rb| rb.contains(user.get_role()))
.unwrap_or(true) =>
{
let user = user.clone();
request.extensions_mut().insert(user);

Expand Down Expand Up @@ -270,7 +280,7 @@ where
tower_http::auth::AsyncRequireAuthorizationLayer::new(Login::<_, _, _, _> {
login_url: None,
redirect_field_name: None,
role_bounds: Box::new(..),
role_bounds: None,
_user_id_type: PhantomData,
_user_type: PhantomData,
_body_type: PhantomData,
Expand All @@ -289,7 +299,7 @@ where
tower_http::auth::AsyncRequireAuthorizationLayer::new(Login::<_, _, _, _> {
login_url: None,
redirect_field_name: None,
role_bounds: Box::new(role_bounds),
role_bounds: Some(Box::new(role_bounds)),
_user_id_type: PhantomData,
_user_type: PhantomData,
_body_type: PhantomData,
Expand All @@ -314,7 +324,7 @@ where
tower_http::auth::AsyncRequireAuthorizationLayer::new(Login::<_, _, _, _> {
login_url: Some(login_url),
redirect_field_name,
role_bounds: Box::new(..),
role_bounds: None,
_user_id_type: PhantomData,
_user_type: PhantomData,
_body_type: PhantomData,
Expand All @@ -341,7 +351,7 @@ where
tower_http::auth::AsyncRequireAuthorizationLayer::new(Login::<_, _, _, _> {
login_url: Some(login_url),
redirect_field_name,
role_bounds: Box::new(role_bounds),
role_bounds: Some(Box::new(role_bounds)),
_user_id_type: PhantomData,
_user_type: PhantomData,
_body_type: PhantomData,
Expand Down Expand Up @@ -370,9 +380,16 @@ mod tests {
AuthLayer, AuthUser,
};

#[derive(Debug, Clone, PartialEq, PartialOrd)]
enum Role {
User,
Admin,
}

#[derive(Debug, Default, Clone)]
struct User {
id: usize,
role: Option<Role>,
password_hash: String,
}

Expand All @@ -385,17 +402,22 @@ mod tests {
}
}

impl AuthUser<usize> for User {
impl AuthUser<usize, Role> for User {
fn get_id(&self) -> usize {
self.id
}

fn get_role(&self) -> Option<Role> {
self.role.clone()
}

fn get_password_hash(&self) -> secrecy::SecretVec<u8> {
secrecy::SecretVec::new(self.password_hash.clone().into())
}
}

type Auth = AuthContext<usize, User, AuthMemoryStore<usize, User>>;
type Auth = AuthContext<usize, User, AuthMemoryStore<usize, User>, Role>;
type RequireAuth = crate::auth::RequireAuthorizationLayer<usize, User, Role>;

#[tokio::test]
async fn logs_user_in() {
Expand Down Expand Up @@ -464,7 +486,6 @@ mod tests {

Ok(Response::new(req.into_body()))
}
type RequireAuth = crate::auth::RequireAuthorizationLayer<usize, User>;

#[tokio::test]
async fn redirects_to_login_url() {
Expand Down Expand Up @@ -626,4 +647,188 @@ mod tests {
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
}

#[tokio::test]
async fn login_with_role_or_redirect() {
let secret = rand::thread_rng().gen::<[u8; 64]>();

let store = MemoryStore::new();
let session_layer = SessionLayer::new(store, &secret);

let store = Arc::new(RwLock::new(HashMap::default()));
let user = User::get_rusty_user();
store.write().await.insert(user.get_id(), user);

let user_store = AuthMemoryStore::new(&store);
let auth_layer = AuthLayer::new(user_store, &secret);

let login_url = Arc::new("/login".into());

let mut service = ServiceBuilder::new()
.layer(session_layer.clone())
.layer(auth_layer.clone())
.service_fn(login);

let mut protected_service = ServiceBuilder::new()
.layer(session_layer)
.layer(auth_layer)
.layer(RequireAuth::login_with_role_or_redirect(
Role::Admin..,
Arc::clone(&login_url),
None,
))
.service_fn(login);

let request = Request::get("/protected").body(Body::empty()).unwrap();
let res = protected_service
.ready()
.await
.unwrap()
.call(request)
.await
.unwrap();
assert_eq!(res.status(), StatusCode::TEMPORARY_REDIRECT);
assert_eq!(
res.headers().get(http::header::LOCATION),
Some(&login_url.as_ref().as_ref().try_into().unwrap())
);
let session_cookie = res.headers().get(SET_COOKIE).unwrap().clone();

let mut request = Request::get("/login").body(Body::empty()).unwrap();
request
.headers_mut()
.insert(COOKIE, session_cookie.to_owned());
let res = service.ready().await.unwrap().call(request).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);

let mut request = Request::get("/protected").body(Body::empty()).unwrap();
request
.headers_mut()
.insert(COOKIE, session_cookie.to_owned());
let res = protected_service
.ready()
.await
.unwrap()
.call(request)
.await
.unwrap();
assert_eq!(res.status(), StatusCode::TEMPORARY_REDIRECT);
assert_eq!(
res.headers().get(http::header::LOCATION),
Some(&login_url.as_ref().as_ref().try_into().unwrap())
);
for (role, status) in [
(Role::User, StatusCode::TEMPORARY_REDIRECT),
(Role::Admin, StatusCode::OK),
] {
let mut user = User::get_rusty_user();
user.role = Some(role);
store.write().await.insert(user.get_id(), user);

let mut request = Request::get("/login").body(Body::empty()).unwrap();
request
.headers_mut()
.insert(COOKIE, session_cookie.to_owned());
let res = service.ready().await.unwrap().call(request).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);

let mut request = Request::get("/protected").body(Body::empty()).unwrap();
request
.headers_mut()
.insert(COOKIE, session_cookie.to_owned());
let res = protected_service
.ready()
.await
.unwrap()
.call(request)
.await
.unwrap();
assert_eq!(res.status(), status);
}
}

#[tokio::test]
async fn login_with_role() {
let secret = rand::thread_rng().gen::<[u8; 64]>();

let store = MemoryStore::new();
let session_layer = SessionLayer::new(store, &secret);

let store = Arc::new(RwLock::new(HashMap::default()));
let user = User::get_rusty_user();
store.write().await.insert(user.get_id(), user);

let user_store = AuthMemoryStore::new(&store);
let auth_layer = AuthLayer::new(user_store, &secret);

let mut service = ServiceBuilder::new()
.layer(session_layer.clone())
.layer(auth_layer.clone())
.service_fn(login);

let mut protected_service = ServiceBuilder::new()
.layer(session_layer)
.layer(auth_layer)
.layer(RequireAuth::login_with_role(Role::Admin..))
.service_fn(login);

let request = Request::get("/protected").body(Body::empty()).unwrap();
let res = protected_service
.ready()
.await
.unwrap()
.call(request)
.await
.unwrap();
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
let session_cookie = res.headers().get(SET_COOKIE).unwrap().clone();

let mut request = Request::get("/login").body(Body::empty()).unwrap();
request
.headers_mut()
.insert(COOKIE, session_cookie.to_owned());
let res = service.ready().await.unwrap().call(request).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);

let mut request = Request::get("/protected").body(Body::empty()).unwrap();
request
.headers_mut()
.insert(COOKIE, session_cookie.to_owned());
let res = protected_service
.ready()
.await
.unwrap()
.call(request)
.await
.unwrap();
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
for (role, status) in [
(Role::User, StatusCode::UNAUTHORIZED),
(Role::Admin, StatusCode::OK),
] {
let mut user = User::get_rusty_user();
user.role = Some(role);
store.write().await.insert(user.get_id(), user);

let mut request = Request::get("/login").body(Body::empty()).unwrap();
request
.headers_mut()
.insert(COOKIE, session_cookie.to_owned());
let res = service.ready().await.unwrap().call(request).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);

let mut request = Request::get("/protected").body(Body::empty()).unwrap();
request
.headers_mut()
.insert(COOKIE, session_cookie.to_owned());
let res = protected_service
.ready()
.await
.unwrap()
.call(request)
.await
.unwrap();
assert_eq!(res.status(), status);
}
}
}