// sshca/src/server/mod.rs
//
// 134 lines
// 3.8 KiB
// Rust

mod host;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use axum::async_trait;
use axum::extract::FromRequestParts;
use axum::headers::authorization::Bearer;
use axum::headers::{Authorization, Host};
use axum::http::request::Parts;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::{RequestPartsExt, Router, TypedHeader};
use tokio::sync::RwLock;
use tracing::debug;
use uuid::Uuid;
use crate::auth::{self, Claims};
use crate::config::Configuration;
use crate::machine_id;
/// Shared application state handed to every request handler via the router.
struct Context {
    /// Server configuration. It is kept behind an `Arc` so it can be passed to
    /// `machine_id::get_machine_id` by cheap clone.
    config: Arc<Configuration>,

    /// Machine-ID cache keyed by hostname. Each value holds the moment the
    /// entry was stored and the machine ID that was looked up.
    cache: RwLock<HashMap<String, (Instant, Uuid)>>,
}

/// Reference-counted handle to [`Context`], used as the axum router state.
type State = Arc<Context>;
/// Rejection produced when a request fails authentication.
///
/// Every failure maps to the same `401 Unauthorized` response. The specific
/// reason is only reported through `debug!` logging at the failure site.
pub struct AuthError;

impl IntoResponse for AuthError {
    fn into_response(self) -> Response {
        let status = StatusCode::UNAUTHORIZED;
        let body = "Unauthorized";
        (status, body).into_response()
    }
}
#[async_trait]
impl FromRequestParts<Arc<Context>> for Claims {
    type Rejection = AuthError;

    /// Authenticates the request and yields its validated [`Claims`].
    ///
    /// 1. Reads the bearer token from the `Authorization` header.
    /// 2. Reads the `Host` header the request was sent to. It falls back to
    ///    `"localhost"` when the header is missing or cannot be extracted.
    /// 3. Takes the token subject (a hostname) and resolves its machine ID
    ///    through the shared cache in `ctx`.
    /// 4. Validates the token against subject, machine ID and `Host` value.
    ///
    /// Any failure is logged at debug level and rejected with [`AuthError`].
    async fn from_request_parts(
        parts: &mut Parts,
        ctx: &State,
    ) -> Result<Self, Self::Rejection> {
        let bearer = match parts
            .extract::<TypedHeader<Authorization<Bearer>>>()
            .await
        {
            Ok(TypedHeader(Authorization(bearer))) => bearer,
            Err(e) => {
                debug!("Failed to extract token from HTTP request: {}", e);
                return Err(AuthError);
            }
        };
        let token = bearer.token();

        // Host the request was addressed to; not to be confused with the
        // token subject below, which names the calling host.
        let request_host = parts
            .extract::<TypedHeader<Host>>()
            .await
            .map(|TypedHeader(h)| h.hostname().to_owned())
            .unwrap_or_else(|_| "localhost".to_owned());

        let subject = match auth::get_token_subject(token) {
            Ok(subject) => subject,
            Err(e) => {
                debug!("Could not get token subject: {}", e);
                return Err(AuthError);
            }
        };

        let machine_id = match get_machine_id(&subject, ctx).await {
            Some(id) => id,
            None => {
                debug!("No machine ID found for host {}", subject);
                return Err(AuthError);
            }
        };

        match auth::validate_token(token, &subject, &machine_id, &request_host) {
            Ok(claims) => {
                debug!("Successfully authenticated request from host {}", subject);
                Ok(claims)
            }
            Err(e) => {
                debug!("Invalid auth token: {}", e);
                Err(AuthError)
            }
        }
    }
}
pub fn make_app(config: Configuration) -> Router {
let ctx = Arc::new(Context {
config: config.into(),
cache: RwLock::new(Default::default()),
});
Router::new()
.route("/", get(|| async { "UP" }))
.route("/host/sign", post(host::sign_host_cert))
.with_state(ctx)
}
async fn get_machine_id(hostname: &str, ctx: &State) -> Option<Uuid> {
let cache = ctx.cache.read().await;
if let Some((ts, m)) = cache.get(hostname) {
if ts.elapsed() < Duration::from_secs(60) {
debug!("Found cached machine ID for {}", hostname);
return Some(*m);
} else {
debug!("Cached machine ID for {} has expired", hostname);
}
}
drop(cache);
let machine_id =
machine_id::get_machine_id(hostname, ctx.config.clone()).await?;
let mut cache = ctx.cache.write().await;
debug!("Caching machine ID for {}", hostname);
cache.insert(hostname.into(), (Instant::now(), machine_id));
Some(machine_id)
}
#[cfg(test)]
mod test {
    use axum::body::Body;
    use axum::http::Request;
    use tower::ServiceExt;

    use super::*;

    /// `GET /` answers `200 OK` with the literal body `UP`.
    #[tokio::test]
    async fn test_up() {
        let request = Request::builder()
            .uri("/")
            .body(Body::empty())
            .expect("a static GET request must build");

        let response = make_app(Configuration::default())
            .oneshot(request)
            .await
            .expect("router call should not fail");
        assert_eq!(response.status(), StatusCode::OK);

        let bytes = hyper::body::to_bytes(response.into_body())
            .await
            .expect("response body should be readable");
        assert_eq!(&bytes[..], b"UP");
    }
}