From 720bb690ea36f22695db9b365a41526de39f84fe Mon Sep 17 00:00:00 2001 From: "Dustin C. Hatch" Date: Sat, 5 Apr 2025 07:24:18 -0500 Subject: [PATCH] auth: Initial JWT implementation We'll use a JWT in the `Authorization` request header to identify the user saving a page. The token will need to be set in the _authorization token_ field in the SingleFile configuration so it will be included when uploading. --- .gitignore | 1 + Cargo.lock | 1 + Cargo.toml | 1 + examples/make-token.rs | 37 ++++++++ src/auth.rs | 201 +++++++++++++++++++++++++++++++++++++++++ src/config.rs | 8 ++ src/main.rs | 15 ++- src/page.rs | 7 +- 8 files changed, 268 insertions(+), 3 deletions(-) create mode 100644 examples/make-token.rs create mode 100644 src/auth.rs diff --git a/.gitignore b/.gitignore index 2c33aa5..704da5f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /target /Rocket.toml +/jwt.secret /meilisearch.token diff --git a/Cargo.lock b/Cargo.lock index e082116..f022629 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1786,6 +1786,7 @@ version = "0.1.0" dependencies = [ "chrono", "html5ever", + "jsonwebtoken", "markup5ever_rcdom", "meilisearch-sdk", "rand 0.9.0", diff --git a/Cargo.toml b/Cargo.toml index 5a9a11c..8f5a836 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" [dependencies] chrono = { version = "0.4.40", default-features = false, features = ["std", "clock", "serde"] } html5ever = "0.27.0" +jsonwebtoken = { version = "9.3.1", default-features = false } markup5ever_rcdom = "0.3.0" meilisearch-sdk = "0.28.0" rand = "0.9.0" diff --git a/examples/make-token.rs b/examples/make-token.rs new file mode 100644 index 0000000..d868cc0 --- /dev/null +++ b/examples/make-token.rs @@ -0,0 +1,37 @@ +use std::io::Read; +use std::time::SystemTime; + +use jsonwebtoken::{encode, EncodingKey, Header}; +use serde::Serialize; + +#[derive(Debug, Serialize)] +struct UserClaims { + aud: String, + exp: u64, + iat: u64, + iss: String, + nbf: u64, + sub: String, +} + +fn main() { + let args: Vec<_> = std::env::args().collect(); + let mut secret = vec![]; + let mut f = std::fs::File::open(&args[1]).unwrap(); + f.read_to_end(&mut secret).unwrap(); + let k = EncodingKey::from_secret(&secret); + let now = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs(); + let claims = UserClaims { + aud: env!("CARGO_PKG_NAME").into(), + exp: now + 604800, + iat: now, + iss: env!("CARGO_PKG_NAME").into(), + nbf: now - 60, + sub: args[2].to_string(), + }; + let jwt = encode(&Header::default(), &claims, &k).unwrap(); + println!("{}", jwt); +} diff --git a/src/auth.rs b/src/auth.rs new file mode 100644 index 0000000..e6420a6 --- /dev/null +++ b/src/auth.rs @@ -0,0 +1,201 @@ +use std::io::Read; + +use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation}; +use rocket::http::Status; +use rocket::request::{FromRequest, Outcome, Request}; +use serde::{Deserialize, Serialize}; +use tracing::error; + +use crate::config::Config; +use crate::Context; + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct UserClaims { + pub aud: String, + pub exp: u64, + pub iat: u64, + pub iss: String, + pub nbf: u64, + pub sub: String, +} + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for UserClaims { + type Error = &'static str; + + async fn from_request(req: &'r Request<'_>) -> Outcome { + let secret = &req.rocket().state::().unwrap().jwt_secret; + match req.headers().get_one("Authorization") { + Some(s) => match extract_token(secret, s) { + Ok(d) => Outcome::Success(d.claims), + Err(e) => Outcome::Error((Status::Unauthorized, e)), + }, + None => Outcome::Error((Status::Unauthorized, "Unauthorized")), + } + } +} + +/// Load the JWT secret from the path specified in the configuration +pub fn load_secret(config: &Config) -> Result, std::io::Error> { + let mut secret = vec![]; + let mut f = std::fs::File::open(&config.auth.jwt_secret)?; + f.read_to_end(&mut secret)?; + Ok(secret) +} + +fn extract_token( + secret: &[u8], + header: &str, +) -> Result, &'static str> { + let mut iter = header.split_ascii_whitespace(); + let scheme = iter.next(); + let token = iter.next(); + if token.is_some() && Some("Bearer") != scheme { + return Err("Unsupported authorization scheme"); + } + if let Some(token) = token { + let mut v = Validation::new(Algorithm::HS256); + v.validate_nbf = true; + v.set_issuer(&[env!("CARGO_PKG_NAME")]); + v.set_audience(&[env!("CARGO_PKG_NAME")]); + let k = DecodingKey::from_secret(secret); + match decode(token, &k, &v) { + Ok(d) => Ok(d), + Err(e) => { + error!("Failed to decode auth token: {}", e); + Err("Invalid token") + }, + } + } else { + Err("Invalid Authorization header") + } +} + +#[cfg(test)] +mod test { + use super::*; + use jsonwebtoken::{encode, EncodingKey, Header}; + use std::time::SystemTime; + + static SECRET: &[u8; 32] = b"6gmrLQ4PGjeUpT3Xs48thx9Cu6XE5pgD"; + + fn now() -> u64 { + SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs() + } + + fn make_token(claims: &UserClaims) -> String { + let k = EncodingKey::from_secret(SECRET); + encode(&Header::default(), claims, &k).unwrap() + } + + #[test] + fn test_extract_token() { + let now = now(); + let claims = UserClaims { + aud: env!("CARGO_PKG_NAME").into(), + exp: now + 60, + iat: now, + iss: env!("CARGO_PKG_NAME").into(), + nbf: now - 60, + sub: "test1".into(), + }; + let jwt = make_token(&claims); + let header = format!("Bearer {}", jwt); + let data = extract_token(SECRET, &header).unwrap(); + assert_eq!(claims, data.claims); + } + + #[test] + fn test_extract_token_expired() { + let now = now() - 600; + let claims = UserClaims { + aud: env!("CARGO_PKG_NAME").into(), + exp: now + 60, + iat: now, + iss: env!("CARGO_PKG_NAME").into(), + nbf: now - 60, + sub: "test1".into(), + }; + let jwt = make_token(&claims); + let header = format!("Bearer {}", jwt); + let err = extract_token(SECRET, &header).unwrap_err(); + assert_eq!(err, "Invalid token"); + } + + #[test] + fn test_extract_token_future() { + let now = now() + 600; + let claims = UserClaims { + aud: env!("CARGO_PKG_NAME").into(), + exp: now + 60, + iat: now, + iss: env!("CARGO_PKG_NAME").into(), + nbf: now - 60, + sub: "test1".into(), + }; + let jwt = make_token(&claims); + let header = format!("Bearer {}", jwt); + let err = extract_token(SECRET, &header).unwrap_err(); + assert_eq!(err, "Invalid token"); + } + + #[test] + fn test_extract_token_bad_issuer() { + let now = now(); + let claims = UserClaims { + aud: env!("CARGO_PKG_NAME").into(), + exp: now + 60, + iat: now, + iss: "mallory".into(), + nbf: now - 60, + sub: "test1".into(), + }; + let jwt = make_token(&claims); + let header = format!("Bearer {}", jwt); + let err = extract_token(SECRET, &header).unwrap_err(); + assert_eq!(err, "Invalid token"); + } + + #[test] + fn test_extract_token_bad_aud() { + let now = now(); + let claims = UserClaims { + aud: "mallory".into(), + exp: now + 60, + iat: now, + iss: env!("CARGO_PKG_NAME").into(), + nbf: now - 60, + sub: "test1".into(), + }; + let jwt = make_token(&claims); + let header = format!("Bearer {}", jwt); + let err = extract_token(SECRET, &header).unwrap_err(); + assert_eq!(err, "Invalid token"); + } + + #[test] + fn test_malformed_header() { + let now = now(); + let claims = UserClaims { + aud: "mallory".into(), + exp: now + 60, + iat: now, + iss: env!("CARGO_PKG_NAME").into(), + nbf: now - 60, + sub: "test1".into(), + }; + let jwt = make_token(&claims); + let err = extract_token(SECRET, &jwt).unwrap_err(); + assert_eq!(err, "Invalid Authorization header"); + } + + #[test] + fn test_unsupported_auth_scheme() { + let header = "Basic dXNlcm5hbWU6cGFzc3dvcmQ="; + let err = extract_token(SECRET, header).unwrap_err(); + assert_eq!(err, "Unsupported authorization scheme"); + } +} diff --git a/src/config.rs b/src/config.rs index 96fe03d..ff93815 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,3 +1,5 @@ +use std::path::PathBuf; + use serde::Deserialize; #[derive(Debug, Deserialize)] @@ -9,8 +11,14 @@ pub struct MeilisearchConfig { pub index: String, } +#[derive(Debug, Deserialize)] +pub struct AuthConfig { + pub jwt_secret: PathBuf, +} + #[derive(Debug, Deserialize)] pub struct Config { + pub auth: AuthConfig, pub meilisearch: MeilisearchConfig, } diff --git a/src/main.rs b/src/main.rs index 24fcd5d..8b81979 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,4 @@ +mod auth; mod config; mod meilisearch; mod page; @@ -15,18 +16,27 @@ struct Context { #[allow(dead_code)] config: Config, client: MeilisearchClient, + jwt_secret: Vec, } #[derive(Debug, thiserror::Error)] enum InitError { #[error("Meilisearch error: {0}")] Meilisearch(#[from] meilisearch::Error), + #[error("Failed to load JWT secret: {0}")] + LoadSecret(std::io::Error), } impl Context { pub fn init(config: Config) -> Result { let client = MeilisearchClient::try_from(&config)?; - Ok(Self { config, client }) + let jwt_secret = + auth::load_secret(&config).map_err(InitError::LoadSecret)?; + Ok(Self { + config, + client, + jwt_secret, + }) } } @@ -42,10 +52,11 @@ pub struct SavePageForm { /// Save a visited page in SingleFile format #[rocket::post("/save", data = "
")] async fn save_page( + user: auth::UserClaims, form: Form, ctx: &State, ) -> Result, String> { - match page::save_page(&form.url, &form.data, ctx).await { + match page::save_page(&form.url, &form.data, ctx, &user).await { Ok(p) => Ok(Json(p)), Err(e) => { error!("Failed to save page: {}", e); diff --git a/src/page.rs b/src/page.rs index 0a4c9e6..838f8d1 100644 --- a/src/page.rs +++ b/src/page.rs @@ -7,6 +7,7 @@ use rand::Rng; use serde::Serialize; use tracing::{debug, event, span, Level}; +use crate::auth::UserClaims; use crate::Context; static ID_CHARSET: &[u8] = b"0123456789abcdefghijklmnopqrstuvwxyz"; @@ -16,6 +17,8 @@ static ID_CHARSET: &[u8] = b"0123456789abcdefghijklmnopqrstuvwxyz"; pub struct Page { /// Unique saved page ID id: String, + /// User ID of page owner + user_id: String, /// Visit timestamp timestamp: DateTime, /// Page URL @@ -31,14 +34,16 @@ pub async fn save_page( url: &str, data: &str, ctx: &Context, + user: &UserClaims, ) -> Result { - let span = span!(Level::INFO, "save_page", url = url); + let span = span!(Level::INFO, "save_page", url = url, user = user.sub); let _guard = span.enter(); let index_name = &ctx.config.meilisearch.index; debug!("Saving page in Meilisearch index {}", index_name); let index = ctx.client.get_index(index_name).await?; let doc = Page { id: gen_id(), + user_id: user.sub.clone(), timestamp: Utc::now(), url: url.into(), title: extract_title(data),