diff --git a/src/auth.rs b/src/auth.rs index aac685b..faafc5e 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -19,7 +19,7 @@ use uuid::Uuid; /// JWT Token Claims #[derive(Debug, Deserialize)] -pub struct Claims { +pub struct HostClaims { /// Token subject (machine hostname) pub sub: String, } @@ -34,7 +34,7 @@ pub fn get_token_subject(token: &str) -> Result { v.insecure_disable_signature_validation(); v.set_required_spec_claims(&["sub"]); let k = DecodingKey::from_secret(b""); - let data: TokenData = decode(token, &k, &v)?; + let data: TokenData = decode(token, &k, &v)?; Ok(data.claims.sub) } @@ -46,12 +46,12 @@ pub fn get_token_subject(token: &str) -> Result { /// `service` argument, and is within its validity period (not before/expires). /// The token must be signed with HMAC-SHA256 using the host's machine ID as /// the secret key. -pub fn validate_token( +pub fn validate_host_token( token: &str, hostname: &str, machine_id: &Uuid, service: &str, -) -> Result { +) -> Result { let mut v = Validation::new(Algorithm::HS256); v.validate_nbf = true; v.set_issuer(&[hostname]); @@ -66,7 +66,7 @@ pub fn validate_token( OsRng.fill_bytes(&mut secret); } let k = DecodingKey::from_secret(&secret); - let data: TokenData = decode(token, &k, &v)?; + let data: TokenData = decode(token, &k, &v)?; Ok(data.claims) } @@ -130,7 +130,12 @@ pub(crate) mod test { let machine_id = uuid!("9afd42e5-4ac3-4530-90c4-191869063ef9"); let token = make_token(hostname, machine_id); - validate_token(&token, hostname, &machine_id, "sshca.example.org") - .unwrap(); + validate_host_token( + &token, + hostname, + &machine_id, + "sshca.example.org", + ) + .unwrap(); } } diff --git a/src/server/host.rs b/src/server/host.rs index ebfa526..ba9cb83 100644 --- a/src/server/host.rs +++ b/src/server/host.rs @@ -1,16 +1,26 @@ use std::collections::HashMap; -use std::time::Duration; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use axum::async_trait; +use axum::extract::FromRequestParts; +use axum::headers::authorization::Bearer; +use axum::headers::{Authorization, Host}; +use axum::http::request::Parts; use axum::extract::multipart::{Multipart, MultipartError}; use axum::extract::State; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; +use axum::{RequestPartsExt, TypedHeader}; use serde::Serialize; use ssh_key::Algorithm; use tracing::{debug, error, info, warn}; +use uuid::Uuid; -use crate::auth::Claims; +use crate::auth::{self, HostClaims}; use crate::ca; +use crate::machine_id; +use super::{AuthError, Context}; #[derive(Serialize)] pub struct SignKeyResponse { @@ -80,6 +90,71 @@ impl IntoResponse for SignKeyError { } } +#[async_trait] +impl FromRequestParts> for HostClaims { + type Rejection = AuthError; + + async fn from_request_parts( + parts: &mut Parts, + ctx: &super::State, + ) -> Result { + let TypedHeader(Authorization(bearer)) = parts + .extract::>>() + .await + .map_err(|e| { + debug!("Failed to extract token from HTTP request: {}", e); + AuthError + })?; + let host = parts.extract::>().await.map_or_else( + |_| "localhost".to_owned(), + |v| v.hostname().to_owned(), + ); + + let hostname = + auth::get_token_subject(bearer.token()).map_err(|e| { + debug!("Could not get token subject: {}", e); + AuthError + })?; + let machine_id = + get_machine_id(&hostname, ctx).await.ok_or_else(|| { + debug!("No machine ID found for host {}", hostname); + AuthError + })?; + let claims = auth::validate_host_token( + bearer.token(), + &hostname, + &machine_id, + &host, + ) + .map_err(|e| { + debug!("Invalid auth token: {}", e); + AuthError + })?; + debug!("Successfully authenticated request from host {}", hostname); + Ok(claims) + } +} + +async fn get_machine_id(hostname: &str, ctx: &super::State) -> Option { + let cache = ctx.cache.read().await; + if let Some((ts, m)) = cache.get(hostname) { + if ts.elapsed() < Duration::from_secs(60) { + debug!("Found cached machine ID for {}", hostname); + return Some(*m); + } else { + debug!("Cached machine ID for {} has expired", hostname); + } + } + drop(cache); + let machine_id = + machine_id::get_machine_id(hostname, ctx.config.clone()).await?; + let mut cache = ctx.cache.write().await; + debug!("Caching machine ID for {}", hostname); + cache.insert(hostname.into(), (Instant::now(), machine_id)); + Some(machine_id) +} + + #[derive(Default)] struct SignKeyRequest { hostname: String, @@ -87,7 +162,7 @@ struct SignKeyRequest { } pub(super) async fn sign_host_cert( - claims: Claims, + claims: HostClaims, State(ctx): State, mut form: Multipart, ) -> Result { @@ -136,8 +211,7 @@ pub(super) async fn sign_host_cert( pubkey.algorithm().as_str(), hostname ); - let cert = - ca::sign_cert(&hostname, &pubkey, duration, &privkey, &[])?; + let cert = ca::sign_cert(&hostname, &pubkey, duration, &privkey, &[])?; info!( "Signed {} key for {}", pubkey.algorithm().as_str(), diff --git a/src/server/mod.rs b/src/server/mod.rs index 08f1a53..79a03f1 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -2,24 +2,16 @@ mod host; use std::collections::HashMap; use std::sync::Arc; -use std::time::{Duration, Instant}; +use std::time::Instant; -use axum::async_trait; -use axum::extract::FromRequestParts; -use axum::headers::authorization::Bearer; -use axum::headers::{Authorization, Host}; -use axum::http::request::Parts; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; use axum::routing::{get, post}; -use axum::{RequestPartsExt, Router, TypedHeader}; +use axum::Router; use tokio::sync::RwLock; -use tracing::debug; use uuid::Uuid; -use crate::auth::{self, Claims}; use crate::config::Configuration; -use crate::machine_id; struct Context { config: Arc, @@ -36,51 +28,6 @@ impl IntoResponse for AuthError { } } -#[async_trait] -impl FromRequestParts> for Claims { - type Rejection = AuthError; - - async fn from_request_parts( - parts: &mut Parts, - ctx: &State, - ) -> Result { - let TypedHeader(Authorization(bearer)) = parts - .extract::>>() - .await - .map_err(|e| { - debug!("Failed to extract token from HTTP request: {}", e); - AuthError - })?; - let host = parts.extract::>().await.map_or_else( - |_| "localhost".to_owned(), - |v| v.hostname().to_owned(), - ); - - let hostname = - auth::get_token_subject(bearer.token()).map_err(|e| { - debug!("Could not get token subject: {}", e); - AuthError - })?; - let machine_id = - get_machine_id(&hostname, ctx).await.ok_or_else(|| { - debug!("No machine ID found for host {}", hostname); - AuthError - })?; - let claims = auth::validate_token( - bearer.token(), - &hostname, - &machine_id, - &host, - ) - .map_err(|e| { - debug!("Invalid auth token: {}", e); - AuthError - })?; - debug!("Successfully authenticated request from host {}", hostname); - Ok(claims) - } -} - pub fn make_app(config: Configuration) -> Router { let ctx = Arc::new(Context { config: config.into(), @@ -92,25 +39,6 @@ pub fn make_app(config: Configuration) -> Router { .with_state(ctx) } -async fn get_machine_id(hostname: &str, ctx: &State) -> Option { - let cache = ctx.cache.read().await; - if let Some((ts, m)) = cache.get(hostname) { - if ts.elapsed() < Duration::from_secs(60) { - debug!("Found cached machine ID for {}", hostname); - return Some(*m); - } else { - debug!("Cached machine ID for {} has expired", hostname); - } - } - drop(cache); - let machine_id = - machine_id::get_machine_id(hostname, ctx.config.clone()).await?; - let mut cache = ctx.cache.write().await; - debug!("Caching machine ID for {}", hostname); - cache.insert(hostname.into(), (Instant::now(), machine_id)); - Some(machine_id) -} - #[cfg(test)] mod test { use axum::body::Body;