auth: Introduce User struct

The `UserClaims` structure is an implementation detail of how the JWT
encoding process works.  We do not need to expose all of the details of
the JWT, such as issuer, audience, expiration, etc. to rest of the
application.  Route handlers should only be concerned with the
information about the user, rather than the metadata about how the user
was authenticated.
This commit is contained in:
2025-04-08 20:29:47 -05:00
parent 3d2772cfc8
commit a50dca7fae
4 changed files with 97 additions and 48 deletions

View File

@@ -29,47 +29,70 @@ use crate::error::{LoginError, OidcError};
use crate::Context; use crate::Context;
#[derive(Debug, PartialEq, Serialize, Deserialize)] #[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct UserClaims { struct UserClaims {
pub aud: String, aud: String,
pub exp: u64, exp: u64,
pub iat: u64, iat: u64,
pub iss: String, iss: String,
pub nbf: u64, nbf: u64,
pub sub: String, sub: String,
}
/// Represents an authenticated user
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct User {
id: String,
}
impl User {
pub fn new(id: &'static str) -> Self {
Self {
id: id.into(),
}
}
/// Return the user ID
///
/// The user ID is an arbitrary string assigned by the identity
/// provider. It does NOT necessarily represent a username, email
/// address, or any other name; it is simply a unique identifier.
pub fn id(&self) -> &str {
&self.id
}
} }
#[rocket::async_trait] #[rocket::async_trait]
impl<'r> FromRequest<'r> for UserClaims { impl<'r> FromRequest<'r> for User {
type Error = &'static str; type Error = &'static str;
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> { async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let secret = &req.rocket().state::<Context>().unwrap().jwt_secret; let secret = &req.rocket().state::<Context>().unwrap().jwt_secret;
match req.headers().get_one("Authorization") { match req.cookies().get("auth.token") {
Some(s) => match extract_token(secret, s) { Some(c) => match UserClaims::from_jwt(c.value(), secret) {
Ok(d) => Outcome::Success(d), Ok(d) => Outcome::Success(d.into()),
Err(e) => Outcome::Error((Status::Unauthorized, e)), Err(e) => {
debug!("Invalid auth token: {}", e);
Outcome::Error((Status::Unauthorized, "Unauthorized"))
},
},
None => match req.headers().get_one("Authorization") {
Some(s) => match extract_token(secret, s) {
Ok(d) => Outcome::Success(d.into()),
Err(e) => Outcome::Error((Status::Unauthorized, e)),
},
None => Outcome::Error((Status::Unauthorized, "Unauthorized")),
}, },
None => Outcome::Error((Status::Unauthorized, "Unauthorized")),
} }
} }
} }
impl UserClaims { impl From<UserClaims> for User {
pub fn new<S: Into<String>>(sub: S, ttl: u64) -> Self { fn from(claims: UserClaims) -> Self {
let now = SystemTime::now() Self { id: claims.sub }
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs();
Self {
aud: env!("CARGO_PKG_NAME").into(),
exp: now + ttl,
iat: now,
iss: env!("CARGO_PKG_NAME").into(),
nbf: now - 60,
sub: sub.into(),
}
} }
}
impl UserClaims {
fn to_jwt( fn to_jwt(
&self, &self,
secret: &[u8], secret: &[u8],
@@ -92,20 +115,34 @@ impl UserClaims {
} }
impl Context { impl Context {
/// Create a signed JWT for a [`UserClaims`] /// Create a signed JWT for a user
pub fn make_jwt( pub fn make_jwt(
&self, &self,
claims: &UserClaims, user: &User,
ttl: u64,
) -> Result<String, jsonwebtoken::errors::Error> { ) -> Result<String, jsonwebtoken::errors::Error> {
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs();
let claims = UserClaims {
aud: env!("CARGO_PKG_NAME").into(),
exp: now + ttl,
iat: now,
iss: env!("CARGO_PKG_NAME").into(),
nbf: now - 60,
sub: user.id().into(),
};
claims.to_jwt(&self.jwt_secret) claims.to_jwt(&self.jwt_secret)
} }
/// Decode a [`UserClaims`] from a JWT /// Decode the [`User`]
pub fn decode_jwt( pub fn decode_jwt(
&self, &self,
token: &str, token: &str,
) -> Result<UserClaims, jsonwebtoken::errors::Error> { ) -> Result<User, jsonwebtoken::errors::Error> {
UserClaims::from_jwt(token, &self.jwt_secret) Ok(UserClaims::from_jwt(token, &self.jwt_secret)?.into())
} }
} }
@@ -342,8 +379,9 @@ pub async fn oidc_callback(
let code = get_auth_code(&state.csrf, &params)?; let code = get_auth_code(&state.csrf, &params)?;
trace!("Got authorization code: {}", code); trace!("Got authorization code: {}", code);
let claims = exchange_code(code, state, ctx).await?; let claims = exchange_code(code, state, ctx).await?;
let user = let user = User {
UserClaims::new(claims.subject().as_str(), config.auth.login_ttl); id: claims.subject().to_string(),
};
let expires = OffsetDateTime::now_utc() let expires = OffsetDateTime::now_utc()
+ Duration::seconds(config.auth.login_ttl.try_into().unwrap_or_else( + Duration::seconds(config.auth.login_ttl.try_into().unwrap_or_else(
|e| { |e| {
@@ -354,7 +392,8 @@ pub async fn oidc_callback(
cookies.add( cookies.add(
Cookie::build(( Cookie::build((
"auth.token", "auth.token",
ctx.make_jwt(&user).map_err(LoginError::from)?, ctx.make_jwt(&user, config.auth.login_ttl)
.map_err(LoginError::from)?,
)) ))
.secure(true) .secure(true)
.http_only(true) .http_only(true)
@@ -491,7 +530,15 @@ mod test {
#[test] #[test]
fn test_userclaims_jwt_round_trip() { fn test_userclaims_jwt_round_trip() {
let claims = UserClaims::new("test1", 60); let now = now();
let claims = UserClaims {
aud: env!("CARGO_PKG_NAME").into(),
exp: now + 60,
iat: now,
iss: env!("CARGO_PKG_NAME").into(),
nbf: now - 60,
sub: "test1".into(),
};
let jwt = claims.to_jwt(SECRET).unwrap(); let jwt = claims.to_jwt(SECRET).unwrap();
let claims1 = UserClaims::from_jwt(&jwt, SECRET).unwrap(); let claims1 = UserClaims::from_jwt(&jwt, SECRET).unwrap();
assert_eq!(claims, claims1); assert_eq!(claims, claims1);

View File

@@ -10,7 +10,7 @@ use rocket::State;
use serde::Serialize; use serde::Serialize;
use tracing::{debug, error, event, span, Level}; use tracing::{debug, error, event, span, Level};
use crate::auth::UserClaims; use crate::auth::User;
use crate::config::Config; use crate::config::Config;
use crate::Context; use crate::Context;
@@ -45,12 +45,12 @@ pub struct SavePageForm {
/// Save a visited page in SingleFile format /// Save a visited page in SingleFile format
#[rocket::post("/save", data = "<form>")] #[rocket::post("/save", data = "<form>")]
pub async fn post_page( pub async fn post_page(
user: UserClaims, user: User,
form: Form<SavePageForm>, form: Form<SavePageForm>,
ctx: &State<Context>, ctx: &State<Context>,
config: &State<Config>, config: &State<Config>,
) -> Result<Json<Page>, String> { ) -> Result<Json<Page>, String> {
match save_page(&form.url, &form.data, ctx, config, &user).await { match save_page(&form.url, &form.data, ctx, config, user.id()).await {
Ok(p) => Ok(Json(p)), Ok(p) => Ok(Json(p)),
Err(e) => { Err(e) => {
error!("Failed to save page: {}", e); error!("Failed to save page: {}", e);
@@ -65,16 +65,16 @@ pub async fn save_page(
data: &str, data: &str,
ctx: &Context, ctx: &Context,
config: &Config, config: &Config,
user: &UserClaims, user: &str,
) -> Result<Page, Error> { ) -> Result<Page, Error> {
let span = span!(Level::INFO, "save_page", url = url, user = user.sub); let span = span!(Level::INFO, "save_page", url = url, user = user);
let _guard = span.enter(); let _guard = span.enter();
let index_name = &config.meilisearch.index; let index_name = &config.meilisearch.index;
debug!("Saving page in Meilisearch index {}", index_name); debug!("Saving page in Meilisearch index {}", index_name);
let index = ctx.client.get_index(index_name).await?; let index = ctx.client.get_index(index_name).await?;
let doc = Page { let doc = Page {
id: gen_id(), id: gen_id(),
user_id: user.sub.clone(), user_id: user.into(),
timestamp: Utc::now(), timestamp: Utc::now(),
url: url.into(), url: url.into(),
title: extract_title(data), title: extract_title(data),

View File

@@ -78,7 +78,9 @@ async fn test_login() {
assert_eq!(location, "/"); assert_eq!(location, "/");
let cookie = res.cookies().get("auth.token").unwrap(); let cookie = res.cookies().get("auth.token").unwrap();
debug!("Cookie: {:?}", cookie); debug!("Cookie: {:?}", cookie);
let claims = ctx.decode_jwt(cookie.value()).unwrap();
debug!("Claims: {:?}", claims); // Check to ensure the cookie contains a valid token
assert!(!claims.sub.is_empty()); let user = ctx.decode_jwt(cookie.value()).unwrap();
debug!("User: {:?}", user);
assert!(!user.id().is_empty());
} }

View File

@@ -5,7 +5,7 @@ use rocket::local::blocking::Client;
use rocket::serde::json::Value; use rocket::serde::json::Value;
use rocket::uri; use rocket::uri;
use seensite::auth::UserClaims; use seensite::auth::User;
use seensite::page::*; use seensite::page::*;
use seensite::Context; use seensite::Context;
@@ -34,8 +34,8 @@ fn test_post_page() {
super::setup(); super::setup();
let client = Client::tracked(seensite::rocket()).unwrap(); let client = Client::tracked(seensite::rocket()).unwrap();
let ctx: &Context = client.rocket().state().unwrap(); let ctx: &Context = client.rocket().state().unwrap();
let claims = UserClaims::new("test1", 60); let user = User::new("test1");
let token = ctx.make_jwt(&claims).unwrap(); let token = ctx.make_jwt(&user, 60).unwrap();
let data = Serializer::new(String::new()) let data = Serializer::new(String::new())
.append_pair("url", TEST_URL) .append_pair("url", TEST_URL)
.append_pair("data", TEST_HTML) .append_pair("data", TEST_HTML)