use std::io::Read; use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation}; use rocket::http::Status; use rocket::request::{FromRequest, Outcome, Request}; use serde::{Deserialize, Serialize}; use tracing::error; use crate::config::Config; use crate::Context; #[derive(Debug, PartialEq, Serialize, Deserialize)] pub struct UserClaims { pub aud: String, pub exp: u64, pub iat: u64, pub iss: String, pub nbf: u64, pub sub: String, } #[rocket::async_trait] impl<'r> FromRequest<'r> for UserClaims { type Error = &'static str; async fn from_request(req: &'r Request<'_>) -> Outcome { let secret = &req.rocket().state::().unwrap().jwt_secret; match req.headers().get_one("Authorization") { Some(s) => match extract_token(secret, s) { Ok(d) => Outcome::Success(d.claims), Err(e) => Outcome::Error((Status::Unauthorized, e)), }, None => Outcome::Error((Status::Unauthorized, "Unauthorized")), } } } /// Load the JWT secret from the path specified in the configuration pub fn load_secret(config: &Config) -> Result, std::io::Error> { let mut secret = vec![]; let mut f = std::fs::File::open(&config.auth.jwt_secret)?; f.read_to_end(&mut secret)?; Ok(secret) } fn extract_token( secret: &[u8], header: &str, ) -> Result, &'static str> { let mut iter = header.split_ascii_whitespace(); let scheme = iter.next(); let token = iter.next(); if token.is_some() && Some("Bearer") != scheme { return Err("Unsupported authorization scheme"); } if let Some(token) = token { let mut v = Validation::new(Algorithm::HS256); v.validate_nbf = true; v.set_issuer(&[env!("CARGO_PKG_NAME")]); v.set_audience(&[env!("CARGO_PKG_NAME")]); let k = DecodingKey::from_secret(secret); match decode(token, &k, &v) { Ok(d) => Ok(d), Err(e) => { error!("Failed to decode auth token: {}", e); Err("Invalid token") }, } } else { Err("Invalid Authorization header") } } #[cfg(test)] mod test { use super::*; use jsonwebtoken::{encode, EncodingKey, Header}; use std::time::SystemTime; static SECRET: &[u8; 32] = b"6gmrLQ4PGjeUpT3Xs48thx9Cu6XE5pgD"; fn now() -> u64 { SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .unwrap() .as_secs() } fn make_token(claims: &UserClaims) -> String { let k = EncodingKey::from_secret(SECRET); encode(&Header::default(), claims, &k).unwrap() } #[test] fn test_extract_token() { let now = now(); let claims = UserClaims { aud: env!("CARGO_PKG_NAME").into(), exp: now + 60, iat: now, iss: env!("CARGO_PKG_NAME").into(), nbf: now - 60, sub: "test1".into(), }; let jwt = make_token(&claims); let header = format!("Bearer {}", jwt); let data = extract_token(SECRET, &header).unwrap(); assert_eq!(claims, data.claims); } #[test] fn test_extract_token_expired() { let now = now() - 600; let claims = UserClaims { aud: env!("CARGO_PKG_NAME").into(), exp: now + 60, iat: now, iss: env!("CARGO_PKG_NAME").into(), nbf: now - 60, sub: "test1".into(), }; let jwt = make_token(&claims); let header = format!("Bearer {}", jwt); let err = extract_token(SECRET, &header).unwrap_err(); assert_eq!(err, "Invalid token"); } #[test] fn test_extract_token_future() { let now = now() + 600; let claims = UserClaims { aud: env!("CARGO_PKG_NAME").into(), exp: now + 60, iat: now, iss: env!("CARGO_PKG_NAME").into(), nbf: now - 60, sub: "test1".into(), }; let jwt = make_token(&claims); let header = format!("Bearer {}", jwt); let err = extract_token(SECRET, &header).unwrap_err(); assert_eq!(err, "Invalid token"); } #[test] fn test_extract_token_bad_issuer() { let now = now(); let claims = UserClaims { aud: env!("CARGO_PKG_NAME").into(), exp: now + 60, iat: now, iss: "mallory".into(), nbf: now - 60, sub: "test1".into(), }; let jwt = make_token(&claims); let header = format!("Bearer {}", jwt); let err = extract_token(SECRET, &header).unwrap_err(); assert_eq!(err, "Invalid token"); } #[test] fn test_extract_token_bad_aud() { let now = now(); let claims = UserClaims { aud: "mallory".into(), exp: now + 60, iat: now, iss: env!("CARGO_PKG_NAME").into(), nbf: now - 60, sub: "test1".into(), }; let jwt = make_token(&claims); let header = format!("Bearer {}", jwt); let err = extract_token(SECRET, &header).unwrap_err(); assert_eq!(err, "Invalid token"); } #[test] fn test_malformed_header() { let now = now(); let claims = UserClaims { aud: "mallory".into(), exp: now + 60, iat: now, iss: env!("CARGO_PKG_NAME").into(), nbf: now - 60, sub: "test1".into(), }; let jwt = make_token(&claims); let err = extract_token(SECRET, &jwt).unwrap_err(); assert_eq!(err, "Invalid Authorization header"); } #[test] fn test_unsupported_auth_scheme() { let header = "Basic dXNlcm5hbWU6cGFzc3dvcmQ="; let err = extract_token(SECRET, header).unwrap_err(); assert_eq!(err, "Unsupported authorization scheme"); } }