sshca/cli/src/main.rs

238 lines
6.1 KiB
Rust

use std::path::PathBuf;
use std::time;
use argh::FromArgs;
use argon2::Argon2;
use jsonwebtoken::{encode, EncodingKey};
use reqwest::multipart::{Form, Part};
use reqwest::{StatusCode, Url};
use serde::Serialize;
use tracing::{debug, error, info};
use tracing_subscriber::EnvFilter;
use uuid::Uuid;
/// Device-tree serial number file, read as the first machine-ID source
/// (present on Raspberry Pi boards).
const RPI_SERIAL_PATH: &str = "/sys/firmware/devicetree/base/serial-number";
/// sysfs file exposing the DMI product UUID, used as the fallback machine-ID
/// source.
const DMI_UUID_PATH: &str = "/sys/class/dmi/id/product_uuid";
// Top-level command-line arguments, parsed with `argh::from_env()` in
// `inner_main`. argh renders the `///` doc comments on these argument types
// as `--help` text, so those lines are user-facing and intentionally kept
// as-is; maintainer notes use plain `//` comments instead.
/// SSH CA client CLI
#[derive(FromArgs)]
struct Args {
// The selected top-level subcommand.
#[argh(subcommand)]
command: Subcommand,
}
// Top-level subcommands. Each variant wraps the argument struct of one
// subcommand; only `host` exists at the moment.
#[derive(FromArgs)]
#[argh(subcommand)]
enum Subcommand {
// `host ...`: host key and certificate operations (dispatched by `host_cmd`).
Host(HostArgs),
}
// Arguments for the `host` subcommand, which only selects a nested
// subcommand. The `///` line below is argh help text.
/// Manage host keys and certificates
#[derive(FromArgs)]
#[argh(subcommand, name = "host")]
struct HostArgs {
// The selected `host` subcommand.
#[argh(subcommand)]
command: HostSubcommand,
}
// Subcommands available under `host`; only `sign` exists at the moment.
#[derive(FromArgs)]
#[argh(subcommand)]
enum HostSubcommand {
// `host sign <pubkey>`: request a certificate from the CA server (`sign_key`).
Sign(SignArgs),
}
// Arguments for `host sign`. The `///` comments in this struct are argh help
// text (subcommand description and per-argument help), so they are
// user-facing and kept unchanged.
/// Request a signed certificate for an SSH public key
#[derive(FromArgs)]
#[argh(subcommand, name = "sign")]
struct SignArgs {
// Optional (`-c`). When omitted, `sign_key` prints the certificate to stdout
// instead of writing it to a file.
/// path to destination SSH host certificate file
#[argh(option, short = 'c')]
output: Option<PathBuf>,
// Required positional argument. `sign_key` uploads the file contents and
// also sends this path's file name as the multipart file name.
/// path to SSH host public key file to sign
#[argh(positional)]
pubkey: PathBuf,
}
/// JWT claim set sent to the CA server as the bearer token.
///
/// serde serializes the fields in declaration order, so reordering them
/// changes the encoded token payload.
#[derive(Debug, Serialize)]
struct Claims {
/// Subject (JWT `sub`).
sub: String,
/// Issuer (JWT `iss`).
iss: String,
/// Audience (JWT `aud`): the server the token is intended for.
aud: String,
/// Issued-at time, Unix timestamp in seconds.
iat: u64,
/// Not-before time, Unix timestamp in seconds.
nbf: u64,
/// Expiry time, Unix timestamp in seconds.
exp: u64,
}
/// Result type shared by the CLI entry points: success carries no value, and
/// any error is boxed so `?` can propagate heterogeneous error types.
type MainResult = std::result::Result<(), Box<dyn std::error::Error + 'static>>;
/// Program entry point.
///
/// Installs a `tracing` subscriber that writes to stderr and is filtered by
/// `EnvFilter::from_default_env()`, then runs the CLI. On failure the error
/// is printed to stderr and the process exits with status 1.
#[tokio::main(flavor = "current_thread")]
async fn main() {
    let filter = EnvFilter::from_default_env();
    tracing_subscriber::fmt()
        .with_env_filter(filter)
        .with_writer(std::io::stderr)
        .init();

    if let Err(err) = inner_main().await {
        eprintln!("{}", err);
        std::process::exit(1);
    }
}
async fn inner_main() -> MainResult {
let args: Args = argh::from_env();
match args.command {
Subcommand::Host(args) => host_cmd(args).await,
}
}
async fn host_cmd(args: HostArgs) -> MainResult {
match args.command {
HostSubcommand::Sign(args) => sign_key(args).await,
}
}
async fn sign_key(args: SignArgs) -> MainResult {
let url = match std::env::var("SSHCA_SERVER") {
Ok(v) => v,
Err(std::env::VarError::NotPresent) => {
return Err("SSHCA_SERVER environment variable is not set".into());
}
Err(std::env::VarError::NotUnicode(_)) => {
return Err("SSHCA_SERVER environment variable is invalid".into());
}
};
let Some(hostname) = get_hostname() else {
return Err("Hostname must be valid UTF-8".into());
};
let Some(machine_id) = get_machine_id() else {
return Err("Could not determine machine ID".into());
};
let pubkey = std::fs::read_to_string(&args.pubkey)?;
let mut url =
Url::parse(&url).map_err(|e| format!("Invalid URL: {}", e))?;
url.path_segments_mut()
.map_err(|_| "Invalid URL: missing host")?
.pop_if_empty()
.push("host")
.push("sign");
let form = Form::new().text("hostname", hostname.clone()).part(
"pubkey",
Part::bytes(pubkey.into_bytes()).file_name(
args.pubkey
.file_name()
.ok_or("Invalid public key file path")?
.to_str()
.ok_or("Invalid public key file path")?
.to_string(),
),
);
let token = get_token(
&hostname,
&machine_id,
url.host_str().ok_or("Invalid URL: missing host")?,
)?;
let client = reqwest::Client::new();
info!(
"Requesting SSH host certificate for {} with key {}",
hostname,
args.pubkey.display()
);
debug!("Request: POST {}", url);
let res = client
.post(url)
.header("Authorization", format!("Bearer {}", token))
.multipart(form)
.send()
.await?;
debug!("Response: {:?} {}", &res.version(), &res.status());
match res.error_for_status_ref() {
Ok(_) => (),
Err(e) if e.status() == Some(StatusCode::BAD_REQUEST) => {
let msg = res.text().await.unwrap_or_else(|e| e.to_string());
error!("{}: {}", e, msg);
return Err(format!("{}\n{}", e, msg).into());
}
Err(e) => {
error!("{}", e);
return Err(e.into());
}
};
let cert = res.text().await?;
if let Some(path) = args.output {
std::fs::write(path, cert)?;
} else {
println!("{}", cert);
}
Ok(())
}
/// Return this machine's hostname, or `None` if it is not valid UTF-8.
fn get_hostname() -> Option<String> {
    match gethostname::gethostname().into_string() {
        Ok(name) => Some(name),
        Err(_) => None,
    }
}
/// Determine a machine identifier as a UUID.
///
/// The sources are tried in order:
/// 1. the device-tree serial number at `RPI_SERIAL_PATH`, with trailing NUL
///    bytes removed and left-padded with `0` to 32 characters;
/// 2. the DMI product UUID at `DMI_UUID_PATH`, with trailing whitespace
///    removed.
///
/// Each source that cannot be read or parsed is logged at `debug` level.
/// Returns `None` when neither source yields a valid UUID.
fn get_machine_id() -> Option<Uuid> {
    // Source 1: device-tree serial number (Raspberry Pi).
    fn from_rpi_serial() -> Option<Uuid> {
        let serial = match std::fs::read_to_string(RPI_SERIAL_PATH) {
            Ok(serial) => serial,
            Err(e) => {
                debug!("Could not read Raspberry Pi serial number: {}", e);
                return None;
            }
        };
        // Left-pad to the 32-hex-digit "simple" UUID form before parsing.
        let padded = format!("{:0>32}", serial.trim_end_matches('\0'));
        match Uuid::parse_str(&padded) {
            Ok(id) => Some(id),
            Err(e) => {
                debug!("Invalid UUID: {}", e);
                None
            }
        }
    }

    // Source 2: DMI product UUID exposed through sysfs.
    fn from_dmi() -> Option<Uuid> {
        let raw = match std::fs::read_to_string(DMI_UUID_PATH) {
            Ok(raw) => raw,
            Err(e) => {
                debug!("Could not read DMI product UUID from sysfs: {}", e);
                return None;
            }
        };
        match Uuid::parse_str(raw.trim_end()) {
            Ok(id) => Some(id),
            Err(e) => {
                debug!("Invalid UUID: {}", e);
                None
            }
        }
    }

    from_rpi_serial().or_else(from_dmi)
}
/// Build the bearer JWT that authenticates this host to the CA server.
///
/// Claims: `sub` and `iss` are `hostname`; `aud` is `server` truncated at its
/// first `:`; `iat` is the current Unix time in seconds. The token is valid
/// from 60 s before to 60 s after `iat`. It is signed with the default
/// `jsonwebtoken` header (HS256), using a 32-byte key derived by Argon2
/// (default parameters) with `machine_id` as the password and `hostname` as
/// the salt.
///
/// # Errors
/// Returns an error if the system clock is before the Unix epoch, if Argon2
/// rejects its inputs, or if JWT encoding fails.
fn get_token(
    hostname: &str,
    machine_id: &Uuid,
    server: &str,
) -> Result<String, Box<dyn std::error::Error>> {
    let now = time::SystemTime::now()
        .duration_since(time::UNIX_EPOCH)?
        .as_secs();
    // NOTE(review): truncating at the first ':' looks intended to drop a port,
    // but `sign_key` passes `Url::host_str()`, which never contains one; for
    // an IPv6 literal such as "[::1]" this yields "[". Left unchanged because
    // the server must compute the same audience — confirm before changing.
    let audience = server.split(':').next().unwrap_or(server);
    let claims = Claims {
        sub: hostname.into(),
        iss: hostname.into(),
        aud: audience.into(),
        // Saturating arithmetic: a clock still near the epoch (e.g. a board
        // without an RTC before time sync) must not underflow. Plain `-`
        // panics in debug builds and wraps to a far-future `nbf` in release.
        nbf: now.saturating_sub(60),
        iat: now,
        exp: now.saturating_add(60),
    };
    let mut secret = [0u8; 32];
    // NOTE(review): the `argon2` crate enforces a minimum salt length
    // (8 bytes), so very short hostnames are expected to fail here; changing
    // the salt would also require a matching server-side change.
    Argon2::default()
        .hash_password_into(machine_id.as_bytes(), hostname.as_bytes(), &mut secret)
        .map_err(|e| e.to_string())?;
    let key = EncodingKey::from_secret(&secret);
    Ok(encode(&Default::default(), &claims, &key)?)
}