diff --git a/src/auth.rs b/src/auth.rs index dd13f5f..6f28e13 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -29,47 +29,70 @@ use crate::error::{LoginError, OidcError}; use crate::Context; #[derive(Debug, PartialEq, Serialize, Deserialize)] -pub struct UserClaims { - pub aud: String, - pub exp: u64, - pub iat: u64, - pub iss: String, - pub nbf: u64, - pub sub: String, +struct UserClaims { + aud: String, + exp: u64, + iat: u64, + iss: String, + nbf: u64, + sub: String, +} + +/// Represents an authenticated user +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct User { + id: String, +} + +impl User { + pub fn new(id: &'static str) -> Self { + Self { + id: id.into(), + } + } + + /// Return the user ID + /// + /// The user ID is an arbitrary string assigned by the identity + /// provider. It does NOT necessarily represent a username, email + /// address, or any other name; it is simply a unique identifier. + pub fn id(&self) -> &str { + &self.id + } } #[rocket::async_trait] -impl<'r> FromRequest<'r> for UserClaims { +impl<'r> FromRequest<'r> for User { type Error = &'static str; async fn from_request(req: &'r Request<'_>) -> Outcome { let secret = &req.rocket().state::().unwrap().jwt_secret; - match req.headers().get_one("Authorization") { - Some(s) => match extract_token(secret, s) { - Ok(d) => Outcome::Success(d), - Err(e) => Outcome::Error((Status::Unauthorized, e)), + match req.cookies().get("auth.token") { + Some(c) => match UserClaims::from_jwt(c.value(), secret) { + Ok(d) => Outcome::Success(d.into()), + Err(e) => { + debug!("Invalid auth token: {}", e); + Outcome::Error((Status::Unauthorized, "Unauthorized")) + }, + }, + None => match req.headers().get_one("Authorization") { + Some(s) => match extract_token(secret, s) { + Ok(d) => Outcome::Success(d.into()), + Err(e) => Outcome::Error((Status::Unauthorized, e)), + }, + None => Outcome::Error((Status::Unauthorized, "Unauthorized")), }, - None => Outcome::Error((Status::Unauthorized, "Unauthorized")), } } } -impl UserClaims { - pub fn new>(sub: S, ttl: u64) -> Self { - let now = SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .unwrap() - .as_secs(); - Self { - aud: env!("CARGO_PKG_NAME").into(), - exp: now + ttl, - iat: now, - iss: env!("CARGO_PKG_NAME").into(), - nbf: now - 60, - sub: sub.into(), - } +impl From for User { + fn from(claims: UserClaims) -> Self { + Self { id: claims.sub } } +} +impl UserClaims { fn to_jwt( &self, secret: &[u8], @@ -92,20 +115,34 @@ impl UserClaims { } impl Context { - /// Create a signed JWT for a [`UserClaims`] + /// Create a signed JWT for a user pub fn make_jwt( &self, - claims: &UserClaims, + user: &User, + ttl: u64, ) -> Result { + let now = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs(); + + let claims = UserClaims { + aud: env!("CARGO_PKG_NAME").into(), + exp: now + ttl, + iat: now, + iss: env!("CARGO_PKG_NAME").into(), + nbf: now - 60, + sub: user.id().into(), + }; claims.to_jwt(&self.jwt_secret) } - /// Decode a [`UserClaims`] from a JWT + /// Decode the [`User`] pub fn decode_jwt( &self, token: &str, - ) -> Result { - UserClaims::from_jwt(token, &self.jwt_secret) + ) -> Result { + Ok(UserClaims::from_jwt(token, &self.jwt_secret)?.into()) } } @@ -342,8 +379,9 @@ pub async fn oidc_callback( let code = get_auth_code(&state.csrf, ¶ms)?; trace!("Got authorization code: {}", code); let claims = exchange_code(code, state, ctx).await?; - let user = - UserClaims::new(claims.subject().as_str(), config.auth.login_ttl); + let user = User { + id: claims.subject().to_string(), + }; let expires = OffsetDateTime::now_utc() + Duration::seconds(config.auth.login_ttl.try_into().unwrap_or_else( |e| { @@ -354,7 +392,8 @@ pub async fn oidc_callback( cookies.add( Cookie::build(( "auth.token", - ctx.make_jwt(&user).map_err(LoginError::from)?, + ctx.make_jwt(&user, config.auth.login_ttl) + .map_err(LoginError::from)?, )) .secure(true) .http_only(true) @@ -491,7 +530,15 @@ mod test { #[test] fn test_userclaims_jwt_round_trip() { - let claims = UserClaims::new("test1", 60); + let now = now(); + let claims = UserClaims { + aud: env!("CARGO_PKG_NAME").into(), + exp: now + 60, + iat: now, + iss: env!("CARGO_PKG_NAME").into(), + nbf: now - 60, + sub: "test1".into(), + }; let jwt = claims.to_jwt(SECRET).unwrap(); let claims1 = UserClaims::from_jwt(&jwt, SECRET).unwrap(); assert_eq!(claims, claims1); diff --git a/src/page.rs b/src/page.rs index 05d393e..7c27108 100644 --- a/src/page.rs +++ b/src/page.rs @@ -10,7 +10,7 @@ use rocket::State; use serde::Serialize; use tracing::{debug, error, event, span, Level}; -use crate::auth::UserClaims; +use crate::auth::User; use crate::config::Config; use crate::Context; @@ -45,12 +45,12 @@ pub struct SavePageForm { /// Save a visited page in SingleFile format #[rocket::post("/save", data = "
")] pub async fn post_page( - user: UserClaims, + user: User, form: Form, ctx: &State, config: &State, ) -> Result, String> { - match save_page(&form.url, &form.data, ctx, config, &user).await { + match save_page(&form.url, &form.data, ctx, config, user.id()).await { Ok(p) => Ok(Json(p)), Err(e) => { error!("Failed to save page: {}", e); @@ -65,16 +65,16 @@ pub async fn save_page( data: &str, ctx: &Context, config: &Config, - user: &UserClaims, + user: &str, ) -> Result { - let span = span!(Level::INFO, "save_page", url = url, user = user.sub); + let span = span!(Level::INFO, "save_page", url = url, user = user); let _guard = span.enter(); let index_name = &config.meilisearch.index; debug!("Saving page in Meilisearch index {}", index_name); let index = ctx.client.get_index(index_name).await?; let doc = Page { id: gen_id(), - user_id: user.sub.clone(), + user_id: user.into(), timestamp: Utc::now(), url: url.into(), title: extract_title(data), diff --git a/tests/integration/auth.rs b/tests/integration/auth.rs index e499d68..c806586 100644 --- a/tests/integration/auth.rs +++ b/tests/integration/auth.rs @@ -78,7 +78,9 @@ async fn test_login() { assert_eq!(location, "/"); let cookie = res.cookies().get("auth.token").unwrap(); debug!("Cookie: {:?}", cookie); - let claims = ctx.decode_jwt(cookie.value()).unwrap(); - debug!("Claims: {:?}", claims); - assert!(!claims.sub.is_empty()); + + // Check to ensure the cookie contains a valid token + let user = ctx.decode_jwt(cookie.value()).unwrap(); + debug!("User: {:?}", user); + assert!(!user.id().is_empty()); } diff --git a/tests/integration/page.rs b/tests/integration/page.rs index 64ce70e..9a1757c 100644 --- a/tests/integration/page.rs +++ b/tests/integration/page.rs @@ -5,7 +5,7 @@ use rocket::local::blocking::Client; use rocket::serde::json::Value; use rocket::uri; -use seensite::auth::UserClaims; +use seensite::auth::User; use seensite::page::*; use seensite::Context; @@ -34,8 +34,8 @@ fn test_post_page() { super::setup(); let client = Client::tracked(seensite::rocket()).unwrap(); let ctx: &Context = client.rocket().state().unwrap(); - let claims = UserClaims::new("test1", 60); - let token = ctx.make_jwt(&claims).unwrap(); + let user = User::new("test1"); + let token = ctx.make_jwt(&user, 60).unwrap(); let data = Serializer::new(String::new()) .append_pair("url", TEST_URL) .append_pair("data", TEST_HTML)