Initial commit

master
Dustin 2023-11-04 13:52:16 -05:00
commit ac9681e0c3
19 changed files with 3419 additions and 0 deletions

10
.editorconfig Normal file
View File

@ -0,0 +1,10 @@
root = true
[*]
charset = utf-8
end_of_line = lf
insert_final_newline = true
trim_trailing_whitespace = true
[**.rs]
max_line_length = 79

1
.gitattributes vendored Normal file
View File

@ -0,0 +1 @@
Cargo.lock -diff

4
.gitignore vendored Normal file
View File

@ -0,0 +1,4 @@
/target
/config.toml
/machine_ids.json
/host_ca_key*

2058
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

33
Cargo.toml Normal file
View File

@ -0,0 +1,33 @@
[package]
name = "sshca"
version = "0.1.0"
edition = "2021"
[dependencies]
argh = "0.1.12"
argon2 = { version = "0.5.2", default-features = false, features = ["alloc"] }
axum = { version = "0.6.20", features = ["multipart", "headers", "json"] }
dirs = "5.0.1"
jsonwebtoken = { version = "8.3.0", default-features = false }
rand_core = { version = "0.6.4", features = ["getrandom"] }
serde = { version = "1.0.190", features = ["derive"] }
serde_json = "1.0.108"
ssh-key = { version = "0.6.2", features = ["serde", "ed25519", "getrandom"] }
tokio = { version = "1.33.0", features = ["rt", "macros", "net", "signal", "fs", "io-util"] }
toml = "0.8.6"
tracing = { version = "0.1.40", features = ["log"] }
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
uuid = "1.5.0"
virt = { version = "0.3.1", optional = true }
[features]
default = ["libvirt"]
libvirt = ["dep:virt"]
[dev-dependencies]
form-data-builder = "1.0.1"
hyper = "0.14.27"
serial_test = "2.0.0"
tempfile = "3.8.1"
tower = { version = "0.4.13", features = ["util"] }

1
rustfmt.toml Normal file
View File

@ -0,0 +1 @@
max_width = 79

136
src/auth.rs Normal file
View File

@ -0,0 +1,136 @@
//! Authentication
//!
//! SSHCA authentication is handled using JWTs. Each request must include a
//! proper token, with the target system's hostname as the subject and issuer,
//! signed by that system's machine ID.
//!
//! To identify identify the subject of a token, use [`get_token_subject`].
//! This function will return the hostname specified in the token; use this
//! to look up the machine ID of the system. Then, use [`validate_token`] to
//! validate the token.
use argon2::Argon2;
use jsonwebtoken::errors::Result;
use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation};
use rand_core::{OsRng, RngCore};
use serde::Deserialize;
use tracing::error;
use uuid::Uuid;
/// JWT Token Claims
#[derive(Debug, Deserialize)]
pub struct Claims {
/// Token subject (machine hostname)
pub sub: String,
}
/// Extract the subject from the token
///
/// This function takes a JWT string and returns the value of the `sub` claim.
/// The signature of the token is not validated, but if the token has expired,
/// an error is returned.
pub fn get_token_subject(token: &str) -> Result<String> {
let mut v = Validation::new(Algorithm::HS256);
v.insecure_disable_signature_validation();
v.set_required_spec_claims(&["sub"]);
let k = DecodingKey::from_secret(b"");
let data: TokenData<Claims> = decode(token, &k, &v)?;
Ok(data.claims.sub)
}
/// Validate a token
///
/// This function validates a JWT for the specified hostname and its
/// corresponding machine ID. The token is valid if it is issued for the
/// specified hostname by the same, has an audience matching the value of the
/// `service` argument, and is within its validity period (not before/expires).
/// The token must be signed with HMAC-SHA256 using the host's machine ID as
/// the secret key.
pub fn validate_token(
token: &str,
hostname: &str,
machine_id: &Uuid,
service: &str,
) -> Result<Claims> {
let mut v = Validation::new(Algorithm::HS256);
v.validate_nbf = true;
v.set_issuer(&[hostname]);
v.set_audience(&[service]);
let mut secret = [0u8; 32];
if let Err(e) = Argon2::default().hash_password_into(
machine_id.as_bytes(),
hostname.as_bytes(),
&mut secret,
) {
error!("Could not derive token secret: {}", e);
OsRng.fill_bytes(&mut secret);
}
let k = DecodingKey::from_secret(&secret);
let data: TokenData<Claims> = decode(token, &k, &v)?;
Ok(data.claims)
}
#[cfg(test)]
pub(crate) mod test {
use std::time;
use jsonwebtoken::{encode, EncodingKey};
use serde::Serialize;
use uuid::uuid;
use super::*;
#[derive(Debug, Serialize)]
struct TestClaims {
sub: String,
iss: String,
aud: String,
iat: u64,
nbf: u64,
exp: u64,
}
pub(crate) fn make_token(hostname: &str, machine_id: Uuid) -> String {
let now = time::SystemTime::now()
.duration_since(time::UNIX_EPOCH)
.unwrap()
.as_secs();
let claims = TestClaims {
sub: hostname.into(),
iss: hostname.into(),
aud: "sshca.example.org".into(),
nbf: now - 60,
iat: now,
exp: now + 60,
};
let mut secret = [0u8; 32];
Argon2::default()
.hash_password_into(
machine_id.as_bytes(),
hostname.as_bytes(),
&mut secret,
)
.unwrap();
let key = EncodingKey::from_secret(&secret);
encode(&Default::default(), &claims, &key).unwrap()
}
#[test]
fn test_get_token_subject() {
let hostname = "web0.example.org";
let machine_id = uuid!("890ec6fe-e6c2-4524-bec0-d474af8aa506");
let token = make_token(hostname, machine_id);
let sub = get_token_subject(&token).unwrap();
assert_eq!(sub, hostname);
}
#[test]
fn test_validate_token() {
let hostname = "file0.example.org";
let machine_id = uuid!("9afd42e5-4ac3-4530-90c4-191869063ef9");
let token = make_token(hostname, machine_id);
validate_token(&token, hostname, &machine_id, "sshca.example.org")
.unwrap();
}
}

171
src/ca.rs Normal file
View File

@ -0,0 +1,171 @@
//! SSH certificate authority
use std::path::Path;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use ssh_key::certificate::{Builder, CertType};
use ssh_key::rand_core::OsRng;
use ssh_key::{Certificate, PrivateKey, PublicKey};
use tokio::fs::File;
use tokio::io::AsyncReadExt;
use tracing::debug;
#[derive(Debug)]
pub enum CertError {
SystemTime(std::time::SystemTimeError),
SshKey(ssh_key::Error),
}
impl From<std::time::SystemTimeError> for CertError {
fn from(e: std::time::SystemTimeError) -> Self {
Self::SystemTime(e)
}
}
impl From<ssh_key::Error> for CertError {
fn from(e: ssh_key::Error) -> Self {
Self::SshKey(e)
}
}
impl std::fmt::Display for CertError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Self::SystemTime(e) => write!(f, "Invalid time: {}", e),
Self::SshKey(e) => write!(f, "SSH key error: {}", e),
}
}
}
impl std::error::Error for CertError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::SystemTime(e) => Some(e),
Self::SshKey(e) => Some(e),
}
}
}
#[derive(Debug)]
pub enum LoadKeyError {
Utf8(std::str::Utf8Error),
Io(std::io::Error),
SshKey(ssh_key::Error),
}
impl From<std::str::Utf8Error> for LoadKeyError {
fn from(e: std::str::Utf8Error) -> Self {
Self::Utf8(e)
}
}
impl From<std::io::Error> for LoadKeyError {
fn from(e: std::io::Error) -> Self {
Self::Io(e)
}
}
impl From<ssh_key::Error> for LoadKeyError {
fn from(e: ssh_key::Error) -> Self {
Self::SshKey(e)
}
}
impl std::fmt::Display for LoadKeyError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Self::Utf8(e) => write!(f, "Invalid key: {}", e),
Self::Io(e) => write!(f, "Could not read key file: {}", e),
Self::SshKey(e) => write!(f, "Could not parse SSH key: {}", e),
}
}
}
impl std::error::Error for LoadKeyError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::Utf8(e) => Some(e),
Self::Io(e) => Some(e),
Self::SshKey(e) => Some(e),
}
}
}
/// Load an SSH private key from a file
pub async fn load_private_key<P>(path: P) -> Result<PrivateKey, LoadKeyError>
where
P: AsRef<Path>,
{
let mut data = Vec::new();
debug!("Loading private key from {}", path.as_ref().display());
let mut f = File::open(path).await?;
f.read_to_end(&mut data).await?;
parse_private_key(&data)
}
/// Parse an SSH private key from a slice of bytes
pub fn parse_private_key(data: &[u8]) -> Result<PrivateKey, LoadKeyError> {
Ok(PrivateKey::from_openssh(std::str::from_utf8(data)?)?)
}
/// Parse an SSH public key from a slice of bytes
pub fn parse_public_key(data: &[u8]) -> Result<PublicKey, LoadKeyError> {
Ok(PublicKey::from_openssh(std::str::from_utf8(data)?)?)
}
/// Create a signed SSH certificate for a host public key
///
/// This function creates a signed certificate for an SSH host public
/// key. The certificate will be valid for the specified hostname and
/// any alias names provided.
pub fn sign_cert(
hostname: &str,
pubkey: &PublicKey,
duration: Duration,
privkey: &PrivateKey,
alias: &[&str],
) -> Result<Certificate, CertError> {
let now = SystemTime::now();
let not_before = now.duration_since(UNIX_EPOCH)?.as_secs();
let not_after = not_before + duration.as_secs();
let mut builder = Builder::new_with_random_nonce(
&mut OsRng, pubkey, not_before, not_after,
)?;
builder.cert_type(CertType::Host)?;
builder.valid_principal(hostname)?;
for a in alias {
builder.valid_principal(*a)?;
}
Ok(builder.sign(privkey)?)
}
#[cfg(test)]
mod test {
use ssh_key::Algorithm;
use super::*;
#[test]
fn test_sign_cert() {
let ca_key =
PrivateKey::random(&mut OsRng, Algorithm::Ed25519).unwrap();
let host_key =
PrivateKey::random(&mut OsRng, Algorithm::Ed25519).unwrap();
let host_pub_key = host_key.public_key();
let duration = Duration::from_secs(86400 * 30);
let hostname = "cloud0.example.org";
let cert = sign_cert(
hostname,
&host_pub_key,
duration,
&ca_key,
&["nextcloud.example.org"],
)
.unwrap();
let valid_principals = cert.valid_principals();
assert_eq!(valid_principals[0], "cloud0.example.org");
assert_eq!(valid_principals[1], "nextcloud.example.org");
}
}

211
src/config.rs Normal file
View File

@ -0,0 +1,211 @@
//! Application configuration
use std::io::ErrorKind;
use std::path::{Path, PathBuf};
use serde::Deserialize;
use tracing::debug;
/// Error returned by [`load_config`]
#[derive(Debug)]
pub enum ConfigError {
Io(std::io::Error),
Toml(toml::de::Error),
}
impl From<std::io::Error> for ConfigError {
fn from(e: std::io::Error) -> Self {
Self::Io(e)
}
}
impl From<toml::de::Error> for ConfigError {
fn from(e: toml::de::Error) -> Self {
Self::Toml(e)
}
}
impl std::fmt::Display for ConfigError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Self::Io(e) => write!(f, "Could not read config file: {}", e),
Self::Toml(e) => write!(f, "Could not parse config: {}", e),
}
}
}
impl std::error::Error for ConfigError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::Io(e) => Some(e),
Self::Toml(e) => Some(e),
}
}
}
/// Host CA Configuration
#[derive(Debug, Deserialize)]
pub struct HostCaConfig {
/// Path to the Host CA private key file
#[serde(default = "default_host_ca_key")]
pub private_key_file: PathBuf,
/// Duration of issued host certificates
#[serde(default = "default_host_cert_duration")]
pub cert_duration: u64,
}
impl Default for HostCaConfig {
fn default() -> Self {
Self {
private_key_file: default_host_ca_key(),
cert_duration: default_host_cert_duration(),
}
}
}
/// CA configuration
#[derive(Debug, Default, Deserialize)]
pub struct CaConfig {
/// Host CA configuration
pub host: HostCaConfig,
}
/// Defines a connection to a libvirt VM host
#[derive(Debug, Deserialize)]
pub struct LibvirtConfig {
/// libvirt Connection URI
pub uri: String,
}
/// Top-level configuration structure
#[derive(Debug, Deserialize)]
pub struct Configuration {
/// List of libvirt connection options
#[serde(default)]
pub libvirt: Vec<LibvirtConfig>,
/// Path to the machine ID map JSON file
#[serde(default = "default_machine_ids")]
pub machine_ids: PathBuf,
/// CA configuration
#[serde(default)]
pub ca: CaConfig,
}
impl Default for Configuration {
fn default() -> Self {
Self {
libvirt: vec![],
machine_ids: default_machine_ids(),
ca: Default::default(),
}
}
}
fn default_config_path(basename: &str) -> PathBuf {
dirs::config_dir().map_or(PathBuf::from(basename), |mut p| {
p.push(env!("CARGO_PKG_NAME"));
p.push(basename);
p
})
}
fn default_machine_ids() -> PathBuf {
default_config_path("machine_ids.json")
}
fn default_host_ca_key() -> PathBuf {
default_config_path("host-ca.key")
}
fn default_host_cert_duration() -> u64 {
86400 * 30
}
/// Load configuration from a TOML file
///
/// If `path` is provided, the configuration will be loaded from the
/// TOML file at that location. If `path` is `None`, the path will be
/// inferred from the XDG Configuration directory (i.e.
/// `${XDG_CONFIG_HOME}/sshca/config.toml`).
///
/// If the configuration file does not exist, the default values will be
/// used. If any error is encountered while reading or parsing the
/// file, a [`ConfigError`] will be returned.
pub fn load_config<P>(path: Option<P>) -> Result<Configuration, ConfigError>
where
P: AsRef<Path>,
{
let path = match path {
Some(p) => PathBuf::from(p.as_ref()),
None => default_config_path("config.toml"),
};
debug!("Loading configuration from {}", path.display());
match std::fs::read_to_string(path) {
Ok(s) => Ok(toml::from_str(&s)?),
Err(ref e) if e.kind() == ErrorKind::NotFound => {
Ok(Default::default())
}
Err(e) => Err(e.into()),
}
}
#[cfg(test)]
mod test {
use serial_test::serial;
use super::*;
#[test]
#[serial]
fn test_default_config() {
std::env::remove_var("XDG_CONFIG_HOME");
std::env::set_var("HOME", "/home/user");
let config = Configuration::default();
assert_eq!(
config.machine_ids,
PathBuf::from("/home/user/.config/sshca/machine_ids.json"),
);
assert_eq!(config.libvirt.len(), 0);
}
#[test]
#[serial]
fn test_default_config_path() {
std::env::remove_var("XDG_CONFIG_HOME");
std::env::set_var("HOME", "/home/user");
let path = default_config_path("config.toml");
assert_eq!(
path,
PathBuf::from("/home/user/.config/sshca/config.toml"),
);
}
#[test]
#[serial]
fn test_default_config_path_env() {
std::env::set_var("XDG_CONFIG_HOME", "/etc");
let path = default_config_path("config.toml");
assert_eq!(path, PathBuf::from("/etc/sshca/config.toml"));
}
#[test]
fn test_config_toml() {
let config_toml = r#"
[[libvirt]]
uri = "qemu+ssh://vmhost0.example.org/system"
[[libvirt]]
uri = "qemu+ssh://vmhost1.example.org/system"
"#;
let config: Configuration = toml::from_str(config_toml).unwrap();
assert_eq!(config.libvirt.len(), 2);
assert_eq!(
config.libvirt[0].uri,
"qemu+ssh://vmhost0.example.org/system"
);
assert_eq!(
config.libvirt[1].uri,
"qemu+ssh://vmhost1.example.org/system"
);
}
}

7
src/lib.rs Normal file
View File

@ -0,0 +1,7 @@
mod auth;
mod ca;
pub mod config;
#[cfg(feature = "libvirt")]
mod libvirt;
mod machine_id;
pub mod server;

53
src/libvirt.rs Normal file
View File

@ -0,0 +1,53 @@
//! libvirt VM info
use tracing::error;
use virt::connect::Connect;
use virt::domain::Domain;
use crate::config::LibvirtConfig;
/// libvirt connection wrapper with auto-disconnect
///
/// This structure wraps a virConnect object and automatically calls
/// virConnectClose when it is dropped.
struct Connection {
conn: Connect,
}
impl Drop for Connection {
fn drop(&mut self) {
if let Err(e) = self.conn.close() {
error!("Failed to close libvirt connection: {}", e);
}
}
}
/// Get the UUID of a libvirt VM
///
/// This function connects to the libvirt daemon specified in the provided
/// [`LibvirtConfig`] structure and attempts to get the UUID of the virtual
/// machine with the given name. If the connection cannot be established, or
/// no machine is found with that name, an error is returned.
pub fn get_machine_id(
name: &str,
config: &LibvirtConfig,
) -> Result<String, virt::error::Error> {
let conn = Connection {
conn: Connect::open_read_only(&config.uri)?,
};
let dom = Domain::lookup_by_name(&conn.conn, name)?;
dom.get_uuid_string()
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_get_machine_id() {
let config = LibvirtConfig {
uri: "test:///default".into(),
};
let machine_id = get_machine_id("test", &config).unwrap();
assert_eq!(&machine_id, "6695eb01-f6a4-8304-79aa-97f2502e193f");
}
}

108
src/machine_id.rs Normal file
View File

@ -0,0 +1,108 @@
//! Look up a known machine ID
use std::sync::Arc;
use serde_json::Value;
use tokio::task;
use tracing::{debug, error};
use uuid::Uuid;
use crate::config::Configuration;
use crate::libvirt;
/// Look up the machine ID for a host
///
/// This function will return the machine ID for a known host, either from the
/// static map stored in a JSON file or from a libvirt VM host. If no location
/// has a record for the machine, `None` is returned.
pub async fn get_machine_id(
hostname: &str,
config: Arc<Configuration>,
) -> Option<Uuid> {
if let Some(v) = from_map(hostname, config.clone()).await {
return parse_uuid(&v);
}
#[cfg(feature = "libvirt")]
if let Some(v) = from_libvirt(hostname, config.clone()).await {
return parse_uuid(&v);
}
None
}
/// Look up a machine ID in the JSON map file
///
/// This function reads the machine ID map stored in the JSON file. If the
/// file could not be opened, is not a valid JSON object, or does not contain
/// an entry for the specified hostname, `None` is returned.
async fn from_map(
hostname: &str,
config: Arc<Configuration>,
) -> Option<String> {
let data = match tokio::fs::read_to_string(&config.machine_ids).await {
Ok(d) => d,
Err(e) => {
error!("Failed to read machine ID map file: {}", e);
return None;
}
};
let res =
task::spawn_blocking(move || serde_json::from_str::<Value>(&data))
.await;
let map = match res {
Ok(Ok(r)) => r,
Ok(Err(e)) => {
error!("Error parsing machine ID map: {}", e);
return None;
}
Err(e) => {
error!("Unexpected error while parsing JSON: {}", e);
return None;
}
};
Some(map.as_object()?.get(hostname)?.as_str()?.to_owned())
}
/// Look up a machine ID on the configured libvirt VM hosts
///
/// This function iterates over the configured VM hosts and checks each one for
/// a domain matching the specified hostname. If no VM host is available or
/// none of the configured hosts have a matching domain, `None` is returned.
#[cfg(feature = "libvirt")]
async fn from_libvirt(
hostname: &str,
config: Arc<Configuration>,
) -> Option<String> {
let hostname = Arc::new(hostname.split('.').next().unwrap().to_string());
let res = task::spawn_blocking(move || {
for libvirt in &config.libvirt {
debug!("Checking {} for {}", libvirt.uri, hostname);
match libvirt::get_machine_id(&hostname, libvirt) {
Ok(v) => return Some(v),
Err(e) => {
debug!("libvirt error: {}", e);
}
};
}
None
})
.await;
match res {
Ok(v) => v,
Err(e) => {
error!("Unexpected error while querying libvirt: {}", e);
None
}
}
}
/// Parse a UUID from a string
///
/// Returns `None` if the string does not contain a valid UUID.
fn parse_uuid(value: &str) -> Option<Uuid> {
Uuid::parse_str(value).map_or_else(
|e| {
error!("Invalid UUID: {}", e);
None
},
Some,
)
}

46
src/main.rs Normal file
View File

@ -0,0 +1,46 @@
use std::path::PathBuf;
use argh::FromArgs;
use tracing_subscriber::EnvFilter;
use sshca::config;
use sshca::server;
/// Online SSH CA Service
#[derive(FromArgs)]
struct Args {
/// path to configuration file
#[argh(option, short = 'c')]
config_file: Option<PathBuf>,
/// listen address
#[argh(option, short = 'l')]
listen_address: Option<String>,
}
#[tokio::main(flavor = "current_thread")]
async fn main() {
tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.with_writer(std::io::stderr)
.init();
// stop libvirt printing to stderr
virt::error::clear_error_callback();
let args: Args = argh::from_env();
let listen_address = args
.listen_address
.as_ref()
.unwrap_or(&"[::]:8087".into())
.parse()
.unwrap();
let config = config::load_config(args.config_file.as_deref()).unwrap();
let app = server::make_app(config);
axum::Server::bind(&listen_address)
.serve(app.into_make_service())
.await
.unwrap();
}

147
src/server/host.rs Normal file
View File

@ -0,0 +1,147 @@
use std::collections::HashMap;
use std::time::Duration;
use axum::extract::multipart::{Multipart, MultipartError};
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use serde::Serialize;
use ssh_key::Algorithm;
use tracing::{debug, error, info, warn};
use crate::auth::Claims;
use crate::ca;
#[derive(Serialize)]
pub struct SignKeyResponse {
success: bool,
errors: Vec<String>,
certificates: HashMap<String, String>,
}
pub enum SignKeyError {
Multipart(MultipartError),
Cert(ca::CertError),
LoadPrivateKey(ca::LoadKeyError),
ParsePublicKey(ca::LoadKeyError),
UnsupportedAlgorithm(String),
NoKey,
}
impl From<MultipartError> for SignKeyError {
fn from(e: MultipartError) -> Self {
Self::Multipart(e)
}
}
impl From<ca::CertError> for SignKeyError {
fn from(e: ca::CertError) -> Self {
Self::Cert(e)
}
}
impl IntoResponse for SignKeyError {
fn into_response(self) -> Response {
match self {
Self::Multipart(e) => {
debug!("Error reading request: {}", e);
let body = e.to_string();
(StatusCode::BAD_REQUEST, body).into_response()
}
Self::Cert(e) => {
error!("Failed to sign certificate: {}", e);
let body = "Service Unavailable";
(StatusCode::SERVICE_UNAVAILABLE, body).into_response()
}
Self::LoadPrivateKey(e) => {
error!("Error loading CA private key: {}", e);
let body = "Service Unavailable";
(StatusCode::SERVICE_UNAVAILABLE, body).into_response()
}
Self::ParsePublicKey(e) => {
error!("Error parsing public keykey: {}", e);
let body = e.to_string();
(StatusCode::BAD_REQUEST, body).into_response()
}
Self::UnsupportedAlgorithm(a) => {
debug!("Requested certificate for unsupported key algorithm \"{}\"", a);
let body = format!("Unsupported key algorithm: {}", a);
(StatusCode::BAD_REQUEST, body).into_response()
}
Self::NoKey => {
debug!("No SSH public key provided in request");
(
StatusCode::BAD_REQUEST,
"No SSH public key provided in request",
)
.into_response()
}
}
}
}
#[derive(Default)]
struct SignKeyRequest {
hostname: String,
pubkey: Vec<u8>,
aliases: Vec<String>,
}
pub(super) async fn sign_host_cert(
claims: Claims,
State(ctx): State<super::State>,
mut form: Multipart,
) -> Result<String, SignKeyError> {
let hostname = claims.sub;
let mut body = SignKeyRequest::default();
while let Some(field) = form.next_field().await? {
match field.name() {
Some("pubkey") => {
body.pubkey = field.bytes().await?.into();
}
Some("alias") => body.aliases.push(field.text().await?),
Some("hostname") => body.hostname = field.text().await?,
Some(n) => {
warn!("Client request included unsupported field {:?}", n);
}
None => {}
}
}
if body.pubkey.is_empty() {
return Err(SignKeyError::NoKey);
}
let aliases: Vec<_> = body.aliases.iter().map(String::as_ref).collect();
let config = &ctx.config;
let duration = Duration::from_secs(config.ca.host.cert_duration);
let privkey = ca::load_private_key(&config.ca.host.private_key_file)
.await
.map_err(SignKeyError::LoadPrivateKey)?;
let pubkey = ca::parse_public_key(&body.pubkey)
.map_err(SignKeyError::ParsePublicKey)?;
match pubkey.algorithm() {
Algorithm::Ecdsa { .. } => (),
Algorithm::Ed25519 => (),
Algorithm::Rsa { .. } => (),
_ => {
return Err(SignKeyError::UnsupportedAlgorithm(
pubkey.algorithm().as_str().into(),
));
}
};
debug!(
"Signing {} key for {}",
pubkey.algorithm().as_str(),
hostname
);
let cert =
ca::sign_cert(&hostname, &pubkey, duration, &privkey, &aliases)?;
info!(
"Signed {} key for {}",
pubkey.algorithm().as_str(),
hostname
);
Ok(cert.to_openssh().map_err(ca::CertError::from)?)
}

110
src/server/mod.rs Normal file
View File

@ -0,0 +1,110 @@
mod host;
use std::sync::Arc;
use axum::async_trait;
use axum::extract::FromRequestParts;
use axum::headers::authorization::Bearer;
use axum::headers::{Authorization, Host};
use axum::http::request::Parts;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::{RequestPartsExt, Router, TypedHeader};
use tracing::debug;
use crate::auth::{self, Claims};
use crate::config::Configuration;
use crate::machine_id;
struct Context {
config: Arc<Configuration>,
}
type State = Arc<Context>;
pub struct AuthError;
impl IntoResponse for AuthError {
fn into_response(self) -> Response {
(StatusCode::UNAUTHORIZED, "Unauthorized").into_response()
}
}
#[async_trait]
impl FromRequestParts<Arc<Context>> for Claims {
type Rejection = AuthError;
async fn from_request_parts(
parts: &mut Parts,
ctx: &State,
) -> Result<Self, Self::Rejection> {
let TypedHeader(Authorization(bearer)) = parts
.extract::<TypedHeader<Authorization<Bearer>>>()
.await
.map_err(|e| {
debug!("Failed to extract token from HTTP request: {}", e);
AuthError
})?;
let host = parts.extract::<TypedHeader<Host>>().await.map_or_else(
|_| "localhost".to_owned(),
|v| v.hostname().to_owned(),
);
let hostname =
auth::get_token_subject(bearer.token()).map_err(|e| {
debug!("Could not get token subject: {}", e);
AuthError
})?;
let machine_id =
machine_id::get_machine_id(&hostname, ctx.config.clone())
.await
.ok_or_else(|| {
debug!("No machine ID found for host {}", hostname);
AuthError
})?;
let claims = auth::validate_token(
bearer.token(),
&hostname,
&machine_id,
&host,
)
.map_err(|e| {
debug!("Invalid auth token: {}", e);
AuthError
})?;
debug!("Successfully authenticated request from host {}", hostname);
Ok(claims)
}
}
pub fn make_app(config: Configuration) -> Router {
let ctx = Arc::new(Context {
config: config.into(),
});
Router::new()
.route("/", get(|| async { "UP" }))
.route("/host/sign", post(host::sign_host_cert))
.with_state(ctx)
}
#[cfg(test)]
mod test {
use axum::body::Body;
use axum::http::Request;
use tower::ServiceExt;
use super::*;
#[tokio::test]
async fn test_up() {
let app = make_app(Configuration::default());
let response = app
.oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
assert_eq!(&body[..], b"UP");
}
}

2
tests/common/mod.rs Normal file
View File

@ -0,0 +1,2 @@
pub mod setup;
pub mod token;

71
tests/common/setup.rs Normal file
View File

@ -0,0 +1,71 @@
use std::error::Error;
use std::io::prelude::*;
use std::path::Path;
use std::sync::Once;
use rand_core::OsRng;
use ssh_key::{Algorithm, Fingerprint, PrivateKey, PublicKey};
use tempfile::NamedTempFile;
use tracing_subscriber::EnvFilter;
use sshca::config::Configuration;
static INIT: Once = Once::new();
fn gen_machine_ids() -> Result<NamedTempFile, Box<dyn Error>> {
let f = NamedTempFile::new()?;
let map = serde_json::json!({
"test.example.org": "b75e9126-d73a-4ae0-9a0d-63cb3552e6cd",
});
serde_json::to_writer(&f, &map)?;
Ok(f)
}
fn gen_config(machine_ids: &Path, host_key: &Path) -> Configuration {
let mut config = Configuration {
machine_ids: machine_ids.to_str().unwrap().into(),
..Default::default()
};
config.ca.host.private_key_file = host_key.to_str().unwrap().into();
config
}
fn gen_ca_key() -> Result<(NamedTempFile, PublicKey), Box<dyn Error>> {
let key = PrivateKey::random(&mut OsRng, Algorithm::Ed25519)?;
let mut f = NamedTempFile::new()?;
f.write_all(key.to_openssh(Default::default())?.as_bytes())?;
Ok((f, key.public_key().clone()))
}
pub async fn setup() -> Result<(TestContext, Configuration), Box<dyn Error>> {
INIT.call_once(|| {
tracing_subscriber::fmt::fmt()
.with_env_filter(EnvFilter::from("sshca=trace"))
.with_test_writer()
.init();
});
let machine_ids = gen_machine_ids()?;
let (host_key, host_key_pub) = gen_ca_key()?;
let config = gen_config(machine_ids.path(), host_key.path());
let ctx = TestContext {
machine_ids,
host_key,
host_key_pub,
};
Ok((ctx, config))
}
#[allow(dead_code)]
pub struct TestContext {
machine_ids: NamedTempFile,
host_key: NamedTempFile,
host_key_pub: PublicKey,
}
impl TestContext {
pub fn host_ca_fingerprint(&self) -> Fingerprint {
self.host_key_pub.fingerprint(Default::default())
}
}

41
tests/common/token.rs Normal file
View File

@ -0,0 +1,41 @@
use std::time;
use argon2::Argon2;
use jsonwebtoken::{encode, EncodingKey};
use serde::Serialize;
use uuid::Uuid;
#[derive(Debug, Serialize)]
struct TestClaims {
sub: String,
iss: String,
aud: String,
iat: u64,
nbf: u64,
exp: u64,
}
pub fn make_token(hostname: &str, machine_id: Uuid) -> String {
let now = time::SystemTime::now()
.duration_since(time::UNIX_EPOCH)
.unwrap()
.as_secs();
let claims = TestClaims {
sub: hostname.into(),
iss: hostname.into(),
aud: "sshca.example.org".into(),
nbf: now - 60,
iat: now,
exp: now + 60,
};
let mut secret = [0u8; 32];
Argon2::default()
.hash_password_into(
machine_id.as_bytes(),
hostname.as_bytes(),
&mut secret,
)
.unwrap();
let key = EncodingKey::from_secret(&secret);
encode(&Default::default(), &claims, &key).unwrap()
}

209
tests/test_host.rs Normal file
View File

@ -0,0 +1,209 @@
mod common;
use axum::body::Body;
use axum::http::{Request, StatusCode};
use form_data_builder::FormData;
use ssh_key::{Algorithm, Certificate};
use tower::ServiceExt;
use uuid::uuid;
use sshca::server::make_app;
use common::setup;
use common::token;
const ED25519_KEY: &str = concat!(
"ssh-ed25519 ",
"AAAAC3NzaC1lZDI1NTE5AAAAIAsFEmrNIoRHHUayEO0NdAIgtMvci/wME07h+A5XSNJy",
);
const DSA_KEY: &str = concat!(
"ssh-dss ",
"AAAAB3NzaC1kc3MAAACBALNAS+fZjaWt4q+MAgjf6HREFoYjgoSVJUUCtNmRGhND85msVtla",
"kll2gLzL6n6TWyiToARlThoTFu1ZDoGYauDL7iDXrGB6VWJEOQZ3TEMHLFYPziW02AbjR9GI",
"ptsF42D0bTTvvaIaBIhOTjWAUjuFIhAKhPkcj+udIcyH8CG1AAAAFQCpbXQSlxOvd4J92j2C",
"rWDYVGoK8wAAAIAfHiV6/glGZrDRztJmw1hfwbmiNPxaoSGkB+Necfkj0fZrlyLj8sLJIbGQ",
"w0dJMATZdRHw3Ql4R5IOu7sBfX1KQW++onT4ads/Xtl6vwfsjO2e/a6Y1ib9JCIOGJxNAAUC",
"JU0Fm0TSv2Nn6UTICAarp1eKALimqkvy1+ygBWjprgAAAIEAic5EpZH9wpgzvl9kPW531yrz",
"IOlCcXsJFPqQxUThrB2o1g3Rjpscd9kCw5UlPu6GGLk4aSN3UxeIKymTuKiEi7tvP1Tj/Bv5",
"tEc4rhfmrBAfAST09oRFDsELufsOAlTrJ0uk2LhtN14H1RBv9qPR5PQKTEYslyvXG1f8itNQ",
"YnQ="
);
fn make_test_request_body(key: &[u8], name: &str) -> (Body, String) {
let mut form = FormData::new(Vec::new());
form.write_file(
"pubkey",
key,
Some(name.as_ref()),
"application/octet-stream",
)
.unwrap();
let content_type = form.content_type_header();
let body = Body::from(form.finish().unwrap());
(body, content_type)
}
fn make_test_request(body: Body, content_type: &str) -> Request<Body> {
let hostname = "test.example.org";
let machine_id = uuid!("b75e9126-d73a-4ae0-9a0d-63cb3552e6cd");
let token = token::make_token(hostname, machine_id);
Request::builder()
.uri("/host/sign")
.method("POST")
.header("Authorization", format!("Bearer {}", token))
.header("Host", "sshca.example.org")
.header("Content-Type", content_type)
.body(body)
.unwrap()
}
#[tokio::test]
async fn test_sign() {
let (ctx, config) = setup::setup().await.unwrap();
let app = make_app(config);
let (body, content_type) = make_test_request_body(
ED25519_KEY.as_bytes(),
"ssh_host_ed25519_key.pub",
);
let req = make_test_request(body, &content_type);
let res = app.oneshot(req).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
let cert = Certificate::from_openssh(std::str::from_utf8(&body).unwrap())
.unwrap();
assert_eq!(cert.algorithm(), Algorithm::Ed25519);
cert.validate(&[ctx.host_ca_fingerprint()]).unwrap();
}
#[tokio::test]
async fn test_sign_invalid() {
let (_ctx, config) = setup::setup().await.unwrap();
let (body, content_type) = make_test_request_body(
"this is not a valid openssh key".as_bytes(),
"ssh_host_ecdsa_key.pub",
);
let app = make_app(config);
let req = make_test_request(body, &content_type);
let res = app.oneshot(req).await.unwrap();
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
assert_eq!(
body,
concat!(
"Could not parse SSH key: ",
"Base64 encoding error: invalid Base64 encoding",
)
);
}
#[tokio::test]
async fn test_sign_nokey() {
let (_ctx, config) = setup::setup().await.unwrap();
let mut form = FormData::new(Vec::new());
let content_type = form.content_type_header();
let body = Body::from(form.finish().unwrap());
let app = make_app(config);
let req = make_test_request(body, &content_type);
let res = app.oneshot(req).await.unwrap();
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
assert_eq!(body, "No SSH public key provided in request",);
}
#[tokio::test]
async fn test_sign_mangled() {
let (_ctx, config) = setup::setup().await.unwrap();
let app = make_app(config);
let mut form = FormData::new(Vec::new());
form.write_file(
"pubkey",
ED25519_KEY.as_bytes(),
Some("ssh_host_ed25519_key.pub".as_ref()),
"application/octet-stream",
)
.unwrap();
let content_type = form.content_type_header();
let mut form_bytes = form.finish().unwrap();
form_bytes.truncate(19);
let body = Body::from(form_bytes);
let req = make_test_request(body, &content_type);
let res = app.oneshot(req).await.unwrap();
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
assert_eq!(body, "Error parsing `multipart/form-data` request",);
}
#[tokio::test]
async fn test_sign_bad_request() {
let (_ctx, config) = setup::setup().await.unwrap();
let app = make_app(config);
let content_type = "text/plain";
let body = Body::from("test");
let req = make_test_request(body, content_type);
let res = app.oneshot(req).await.unwrap();
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
assert_eq!(body, "Invalid `boundary` for `multipart/form-data` request",);
}
#[tokio::test]
async fn test_sign_dsa() {
let (_ctx, config) = setup::setup().await.unwrap();
let app = make_app(config);
let (body, content_type) =
make_test_request_body(DSA_KEY.as_bytes(), "ssh_host_dsa_key.pub");
let req = make_test_request(body, &content_type);
let res = app.oneshot(req).await.unwrap();
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
assert_eq!(body, "Unsupported key algorithm: ssh-dss");
}
#[tokio::test]
async fn test_sign_failure() {
let (_ctx, mut config) = setup::setup().await.unwrap();
config.ca.host.private_key_file = "bogus".into();
let app = make_app(config);
let (body, content_type) = make_test_request_body(
ED25519_KEY.as_bytes(),
"ssh_host_ed25519_key.pub",
);
let req = make_test_request(body, &content_type);
let res = app.oneshot(req).await.unwrap();
assert_eq!(res.status(), StatusCode::SERVICE_UNAVAILABLE);
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
assert_eq!(body, "Service Unavailable");
}
#[tokio::test]
async fn test_sign_unauthorized() {
// Deliberately drop the TestContext so the machine ID file gets deleted,
// which will cause authentication to fail.
let (_, config) = setup::setup().await.unwrap();
let app = make_app(config);
let (body, content_type) = make_test_request_body(
ED25519_KEY.as_bytes(),
"ssh_host_ed25519_key.pub",
);
let req = make_test_request(body, &content_type);
let res = app.oneshot(req).await.unwrap();
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
assert_eq!(body, "Unauthorized");
}