use std::path::PathBuf; use std::time; use argh::FromArgs; use argon2::Argon2; use jsonwebtoken::{encode, EncodingKey}; use reqwest::multipart::{Form, Part}; use reqwest::{StatusCode, Url}; use serde::Serialize; use tracing::{debug, error, info}; use tracing_subscriber::EnvFilter; use uuid::Uuid; static RPI_SERIAL_PATH: &str = "/sys/firmware/devicetree/base/serial-number"; static DMI_UUID_PATH: &str = "/sys/class/dmi/id/product_uuid"; /// SSH CA client CLI #[derive(FromArgs)] struct Args { #[argh(subcommand)] command: Subcommand, } #[derive(FromArgs)] #[argh(subcommand)] enum Subcommand { Host(HostArgs), } /// Manage host keys and certificates #[derive(FromArgs)] #[argh(subcommand, name = "host")] struct HostArgs { #[argh(subcommand)] command: HostSubcommand, } #[derive(FromArgs)] #[argh(subcommand)] enum HostSubcommand { Sign(SignArgs), } /// Request a signed certificate for an SSH public key #[derive(FromArgs)] #[argh(subcommand, name = "sign")] struct SignArgs { /// path to destination SSH host certificate file #[argh(option, short = 'c')] output: Option, /// path to SSH host public key file to sign #[argh(positional)] pubkey: PathBuf, } #[derive(Debug, Serialize)] struct Claims { sub: String, iss: String, aud: String, iat: u64, nbf: u64, exp: u64, } type MainResult = Result<(), Box>; #[tokio::main(flavor = "current_thread")] async fn main() { tracing_subscriber::fmt() .with_env_filter(EnvFilter::from_default_env()) .with_writer(std::io::stderr) .init(); match inner_main().await { Ok(_) => (), Err(e) => { eprintln!("{}", e); std::process::exit(1); } } } async fn inner_main() -> MainResult { let args: Args = argh::from_env(); match args.command { Subcommand::Host(args) => host_cmd(args).await, } } async fn host_cmd(args: HostArgs) -> MainResult { match args.command { HostSubcommand::Sign(args) => sign_key(args).await, } } async fn sign_key(args: SignArgs) -> MainResult { let url = match std::env::var("SSHCA_SERVER") { Ok(v) => v, Err(std::env::VarError::NotPresent) => { return Err("SSHCA_SERVER environment variable is not set".into()); } Err(std::env::VarError::NotUnicode(_)) => { return Err("SSHCA_SERVER environment variable is invalid".into()); } }; let Some(hostname) = get_hostname() else { return Err("Hostname must be valid UTF-8".into()); }; let Some(machine_id) = get_machine_id() else { return Err("Could not determine machine ID".into()); }; let pubkey = std::fs::read_to_string(&args.pubkey)?; let mut url = Url::parse(&url).map_err(|e| format!("Invalid URL: {}", e))?; url.path_segments_mut() .map_err(|_| "Invalid URL: missing host")? .pop_if_empty() .push("host") .push("sign"); let form = Form::new().text("hostname", hostname.clone()).part( "pubkey", Part::bytes(pubkey.into_bytes()).file_name( args.pubkey .file_name() .ok_or("Invalid public key file path")? .to_str() .ok_or("Invalid public key file path")? .to_string(), ), ); let token = get_token( &hostname, &machine_id, url.host_str().ok_or("Invalid URL: missing host")?, )?; let client = reqwest::Client::new(); info!( "Requesting SSH host certificate for {} with key {}", hostname, args.pubkey.display() ); debug!("Request: POST {}", url); let res = client .post(url) .header("Authorization", format!("Bearer {}", token)) .multipart(form) .send() .await?; debug!("Response: {:?} {}", &res.version(), &res.status()); match res.error_for_status_ref() { Ok(_) => (), Err(e) if e.status() == Some(StatusCode::BAD_REQUEST) => { let msg = res.text().await.unwrap_or_else(|e| e.to_string()); error!("{}: {}", e, msg); return Err(format!("{}\n{}", e, msg).into()); } Err(e) => { error!("{}", e); return Err(e.into()); } }; let cert = res.text().await?; if let Some(path) = args.output { std::fs::write(path, cert)?; } else { println!("{}", cert); } Ok(()) } fn get_hostname() -> Option { gethostname::gethostname().into_string().ok() } fn get_machine_id() -> Option { match std::fs::read_to_string(RPI_SERIAL_PATH) { Ok(s) => match Uuid::parse_str(&format!( "{:0>32}", s.trim_end_matches('\0') )) { Ok(u) => return Some(u), Err(e) => { debug!("Invalid UUID: {}", e); } }, Err(e) => { debug!("Could not read Raspberry Pi serial number: {}", e) } }; match std::fs::read_to_string(DMI_UUID_PATH) { Ok(s) => match Uuid::parse_str(s.trim_end()) { Ok(u) => return Some(u), Err(e) => { debug!("Invalid UUID: {}", e); } }, Err(e) => { debug!("Could not read DMI product UUID from sysfs: {}", e); } }; None } fn get_token( hostname: &str, machine_id: &Uuid, server: &str, ) -> Result> { let now = time::SystemTime::now() .duration_since(time::UNIX_EPOCH)? .as_secs(); let claims = Claims { sub: hostname.into(), iss: hostname.into(), aud: server.split(':').next().unwrap().into(), nbf: now - 60, iat: now, exp: now + 60, }; let mut secret = [0u8; 32]; Argon2::default() .hash_password_into( machine_id.as_bytes(), hostname.as_bytes(), &mut secret, ) .map_err(|e| e.to_string())?; let key = EncodingKey::from_secret(&secret); Ok(encode(&Default::default(), &claims, &key)?) }