Initial commit
commit
cadc977700
|
@ -0,0 +1,10 @@
|
||||||
|
root = true
|
||||||
|
|
||||||
|
[*]
|
||||||
|
charset = utf-8
|
||||||
|
end_of_line = lf
|
||||||
|
insert_final_newline = true
|
||||||
|
trim_trailing_whitespace = true
|
||||||
|
|
||||||
|
[**.rs]
|
||||||
|
max_line_length = 79
|
|
@ -0,0 +1 @@
|
||||||
|
Cargo.lock -diff
|
|
@ -0,0 +1,4 @@
|
||||||
|
/target
|
||||||
|
/config.toml
|
||||||
|
/machine_ids.json
|
||||||
|
/host_ca_key*
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,33 @@
|
||||||
|
[package]
|
||||||
|
name = "sshca"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
argh = "0.1.12"
|
||||||
|
argon2 = { version = "0.5.2", default-features = false, features = ["alloc"] }
|
||||||
|
axum = { version = "0.6.20", features = ["multipart", "headers", "json"] }
|
||||||
|
dirs = "5.0.1"
|
||||||
|
jsonwebtoken = { version = "8.3.0", default-features = false }
|
||||||
|
rand_core = { version = "0.6.4", features = ["getrandom"] }
|
||||||
|
serde = { version = "1.0.190", features = ["derive"] }
|
||||||
|
serde_json = "1.0.108"
|
||||||
|
ssh-key = { version = "0.6.2", features = ["serde", "ed25519", "getrandom"] }
|
||||||
|
tokio = { version = "1.33.0", features = ["rt", "macros", "net", "signal", "fs", "io-util"] }
|
||||||
|
toml = "0.8.6"
|
||||||
|
tracing = { version = "0.1.40", features = ["log"] }
|
||||||
|
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
|
||||||
|
uuid = "1.5.0"
|
||||||
|
virt = { version = "0.3.1", optional = true }
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = ["libvirt"]
|
||||||
|
libvirt = ["dep:virt"]
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
form-data-builder = "1.0.1"
|
||||||
|
hyper = "0.14.27"
|
||||||
|
serial_test = "2.0.0"
|
||||||
|
tempfile = "3.8.1"
|
||||||
|
tower = { version = "0.4.13", features = ["util"] }
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
max_width = 79
|
|
@ -0,0 +1,136 @@
|
||||||
|
//! Authentication
|
||||||
|
//!
|
||||||
|
//! SSHCA authentication is handled using JWTs. Each request must include a
|
||||||
|
//! proper token, with the target system's hostname as the subject and issuer,
|
||||||
|
//! signed by that system's machine ID.
|
||||||
|
//!
|
||||||
|
//! To identify identify the subject of a token, use [`get_token_subject`].
|
||||||
|
//! This function will return the hostname specified in the token; use this
|
||||||
|
//! to look up the machine ID of the system. Then, use [`validate_token`] to
|
||||||
|
//! validate the token.
|
||||||
|
|
||||||
|
use argon2::Argon2;
|
||||||
|
use jsonwebtoken::errors::Result;
|
||||||
|
use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation};
|
||||||
|
use rand_core::{OsRng, RngCore};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use tracing::error;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
/// JWT Token Claims
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct Claims {
|
||||||
|
/// Token subject (machine hostname)
|
||||||
|
pub sub: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract the subject from the token
|
||||||
|
///
|
||||||
|
/// This function takes a JWT string and returns the value of the `sub` claim.
|
||||||
|
/// The signature of the token is not validated, but if the token has expired,
|
||||||
|
/// an error is returned.
|
||||||
|
pub fn get_token_subject(token: &str) -> Result<String> {
|
||||||
|
let mut v = Validation::new(Algorithm::HS256);
|
||||||
|
v.insecure_disable_signature_validation();
|
||||||
|
v.set_required_spec_claims(&["sub"]);
|
||||||
|
let k = DecodingKey::from_secret(b"");
|
||||||
|
let data: TokenData<Claims> = decode(token, &k, &v)?;
|
||||||
|
Ok(data.claims.sub)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate a token
|
||||||
|
///
|
||||||
|
/// This function validates a JWT for the specified hostname and its
|
||||||
|
/// corresponding machine ID. The token is valid if it is issued for the
|
||||||
|
/// specified hostname by the same, has an audience matching the value of the
|
||||||
|
/// `service` argument, and is within its validity period (not before/expires).
|
||||||
|
/// The token must be signed with HMAC-SHA256 using the host's machine ID as
|
||||||
|
/// the secret key.
|
||||||
|
pub fn validate_token(
|
||||||
|
token: &str,
|
||||||
|
hostname: &str,
|
||||||
|
machine_id: &Uuid,
|
||||||
|
service: &str,
|
||||||
|
) -> Result<Claims> {
|
||||||
|
let mut v = Validation::new(Algorithm::HS256);
|
||||||
|
v.validate_nbf = true;
|
||||||
|
v.set_issuer(&[hostname]);
|
||||||
|
v.set_audience(&[service]);
|
||||||
|
let mut secret = [0u8; 32];
|
||||||
|
if let Err(e) = Argon2::default().hash_password_into(
|
||||||
|
machine_id.as_bytes(),
|
||||||
|
hostname.as_bytes(),
|
||||||
|
&mut secret,
|
||||||
|
) {
|
||||||
|
error!("Could not derive token secret: {}", e);
|
||||||
|
OsRng.fill_bytes(&mut secret);
|
||||||
|
}
|
||||||
|
let k = DecodingKey::from_secret(&secret);
|
||||||
|
let data: TokenData<Claims> = decode(token, &k, &v)?;
|
||||||
|
Ok(data.claims)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub(crate) mod test {
|
||||||
|
use std::time;
|
||||||
|
|
||||||
|
use jsonwebtoken::{encode, EncodingKey};
|
||||||
|
use serde::Serialize;
|
||||||
|
use uuid::uuid;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct TestClaims {
|
||||||
|
sub: String,
|
||||||
|
iss: String,
|
||||||
|
aud: String,
|
||||||
|
iat: u64,
|
||||||
|
nbf: u64,
|
||||||
|
exp: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn make_token(hostname: &str, machine_id: Uuid) -> String {
|
||||||
|
let now = time::SystemTime::now()
|
||||||
|
.duration_since(time::UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_secs();
|
||||||
|
let claims = TestClaims {
|
||||||
|
sub: hostname.into(),
|
||||||
|
iss: hostname.into(),
|
||||||
|
aud: "sshca.example.org".into(),
|
||||||
|
nbf: now - 60,
|
||||||
|
iat: now,
|
||||||
|
exp: now + 60,
|
||||||
|
};
|
||||||
|
let mut secret = [0u8; 32];
|
||||||
|
Argon2::default()
|
||||||
|
.hash_password_into(
|
||||||
|
machine_id.as_bytes(),
|
||||||
|
hostname.as_bytes(),
|
||||||
|
&mut secret,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let key = EncodingKey::from_secret(&secret);
|
||||||
|
encode(&Default::default(), &claims, &key).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_get_token_subject() {
|
||||||
|
let hostname = "web0.example.org";
|
||||||
|
let machine_id = uuid!("890ec6fe-e6c2-4524-bec0-d474af8aa506");
|
||||||
|
let token = make_token(hostname, machine_id);
|
||||||
|
let sub = get_token_subject(&token).unwrap();
|
||||||
|
assert_eq!(sub, hostname);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_token() {
|
||||||
|
let hostname = "file0.example.org";
|
||||||
|
|
||||||
|
let machine_id = uuid!("9afd42e5-4ac3-4530-90c4-191869063ef9");
|
||||||
|
let token = make_token(hostname, machine_id);
|
||||||
|
validate_token(&token, hostname, &machine_id, "sshca.example.org")
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,171 @@
|
||||||
|
//! SSH certificate authority
|
||||||
|
use std::path::Path;
|
||||||
|
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
|
use ssh_key::certificate::{Builder, CertType};
|
||||||
|
use ssh_key::rand_core::OsRng;
|
||||||
|
use ssh_key::{Certificate, PrivateKey, PublicKey};
|
||||||
|
use tokio::fs::File;
|
||||||
|
use tokio::io::AsyncReadExt;
|
||||||
|
use tracing::debug;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum CertError {
|
||||||
|
SystemTime(std::time::SystemTimeError),
|
||||||
|
SshKey(ssh_key::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::time::SystemTimeError> for CertError {
|
||||||
|
fn from(e: std::time::SystemTimeError) -> Self {
|
||||||
|
Self::SystemTime(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ssh_key::Error> for CertError {
|
||||||
|
fn from(e: ssh_key::Error) -> Self {
|
||||||
|
Self::SshKey(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for CertError {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
Self::SystemTime(e) => write!(f, "Invalid time: {}", e),
|
||||||
|
Self::SshKey(e) => write!(f, "SSH key error: {}", e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for CertError {
|
||||||
|
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||||
|
match self {
|
||||||
|
Self::SystemTime(e) => Some(e),
|
||||||
|
Self::SshKey(e) => Some(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum LoadKeyError {
|
||||||
|
Utf8(std::str::Utf8Error),
|
||||||
|
Io(std::io::Error),
|
||||||
|
SshKey(ssh_key::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::str::Utf8Error> for LoadKeyError {
|
||||||
|
fn from(e: std::str::Utf8Error) -> Self {
|
||||||
|
Self::Utf8(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for LoadKeyError {
|
||||||
|
fn from(e: std::io::Error) -> Self {
|
||||||
|
Self::Io(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ssh_key::Error> for LoadKeyError {
|
||||||
|
fn from(e: ssh_key::Error) -> Self {
|
||||||
|
Self::SshKey(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for LoadKeyError {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
Self::Utf8(e) => write!(f, "Invalid key: {}", e),
|
||||||
|
Self::Io(e) => write!(f, "Could not read key file: {}", e),
|
||||||
|
Self::SshKey(e) => write!(f, "Could not parse SSH key: {}", e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for LoadKeyError {
|
||||||
|
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||||
|
match self {
|
||||||
|
Self::Utf8(e) => Some(e),
|
||||||
|
Self::Io(e) => Some(e),
|
||||||
|
Self::SshKey(e) => Some(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load an SSH private key from a file
|
||||||
|
pub async fn load_private_key<P>(path: P) -> Result<PrivateKey, LoadKeyError>
|
||||||
|
where
|
||||||
|
P: AsRef<Path>,
|
||||||
|
{
|
||||||
|
let mut data = Vec::new();
|
||||||
|
debug!("Loading private key from {}", path.as_ref().display());
|
||||||
|
let mut f = File::open(path).await?;
|
||||||
|
f.read_to_end(&mut data).await?;
|
||||||
|
parse_private_key(&data)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse an SSH private key from a slice of bytes
|
||||||
|
pub fn parse_private_key(data: &[u8]) -> Result<PrivateKey, LoadKeyError> {
|
||||||
|
Ok(PrivateKey::from_openssh(std::str::from_utf8(data)?)?)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse an SSH public key from a slice of bytes
|
||||||
|
pub fn parse_public_key(data: &[u8]) -> Result<PublicKey, LoadKeyError> {
|
||||||
|
Ok(PublicKey::from_openssh(std::str::from_utf8(data)?)?)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a signed SSH certificate for a host public key
|
||||||
|
///
|
||||||
|
/// This function creates a signed certificate for an SSH host public
|
||||||
|
/// key. The certificate will be valid for the specified hostname and
|
||||||
|
/// any alias names provided.
|
||||||
|
pub fn sign_cert(
|
||||||
|
hostname: &str,
|
||||||
|
pubkey: &PublicKey,
|
||||||
|
duration: Duration,
|
||||||
|
privkey: &PrivateKey,
|
||||||
|
alias: &[&str],
|
||||||
|
) -> Result<Certificate, CertError> {
|
||||||
|
let now = SystemTime::now();
|
||||||
|
let not_before = now.duration_since(UNIX_EPOCH)?.as_secs();
|
||||||
|
let not_after = not_before + duration.as_secs();
|
||||||
|
|
||||||
|
let mut builder = Builder::new_with_random_nonce(
|
||||||
|
&mut OsRng, pubkey, not_before, not_after,
|
||||||
|
)?;
|
||||||
|
builder.cert_type(CertType::Host)?;
|
||||||
|
builder.valid_principal(hostname)?;
|
||||||
|
for a in alias {
|
||||||
|
builder.valid_principal(*a)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(builder.sign(privkey)?)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test {
|
||||||
|
use ssh_key::Algorithm;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_sign_cert() {
|
||||||
|
let ca_key =
|
||||||
|
PrivateKey::random(&mut OsRng, Algorithm::Ed25519).unwrap();
|
||||||
|
let host_key =
|
||||||
|
PrivateKey::random(&mut OsRng, Algorithm::Ed25519).unwrap();
|
||||||
|
let host_pub_key = host_key.public_key();
|
||||||
|
let duration = Duration::from_secs(86400 * 30);
|
||||||
|
let hostname = "cloud0.example.org";
|
||||||
|
let cert = sign_cert(
|
||||||
|
hostname,
|
||||||
|
&host_pub_key,
|
||||||
|
duration,
|
||||||
|
&ca_key,
|
||||||
|
&["nextcloud.example.org"],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let valid_principals = cert.valid_principals();
|
||||||
|
assert_eq!(valid_principals[0], "cloud0.example.org");
|
||||||
|
assert_eq!(valid_principals[1], "nextcloud.example.org");
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,211 @@
|
||||||
|
//! Application configuration
|
||||||
|
use std::io::ErrorKind;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
|
use serde::Deserialize;
|
||||||
|
use tracing::debug;
|
||||||
|
|
||||||
|
/// Error returned by [`load_config`]
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum ConfigError {
|
||||||
|
Io(std::io::Error),
|
||||||
|
Toml(toml::de::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for ConfigError {
|
||||||
|
fn from(e: std::io::Error) -> Self {
|
||||||
|
Self::Io(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<toml::de::Error> for ConfigError {
|
||||||
|
fn from(e: toml::de::Error) -> Self {
|
||||||
|
Self::Toml(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for ConfigError {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
Self::Io(e) => write!(f, "Could not read config file: {}", e),
|
||||||
|
Self::Toml(e) => write!(f, "Could not parse config: {}", e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for ConfigError {
|
||||||
|
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||||
|
match self {
|
||||||
|
Self::Io(e) => Some(e),
|
||||||
|
Self::Toml(e) => Some(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Host CA Configuration
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct HostCaConfig {
|
||||||
|
/// Path to the Host CA private key file
|
||||||
|
#[serde(default = "default_host_ca_key")]
|
||||||
|
pub private_key_file: PathBuf,
|
||||||
|
|
||||||
|
/// Duration of issued host certificates
|
||||||
|
#[serde(default = "default_host_cert_duration")]
|
||||||
|
pub cert_duration: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for HostCaConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
private_key_file: default_host_ca_key(),
|
||||||
|
cert_duration: default_host_cert_duration(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// CA configuration
|
||||||
|
#[derive(Debug, Default, Deserialize)]
|
||||||
|
pub struct CaConfig {
|
||||||
|
/// Host CA configuration
|
||||||
|
pub host: HostCaConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Defines a connection to a libvirt VM host
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct LibvirtConfig {
|
||||||
|
/// libvirt Connection URI
|
||||||
|
pub uri: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Top-level configuration structure
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct Configuration {
|
||||||
|
/// List of libvirt connection options
|
||||||
|
#[serde(default)]
|
||||||
|
pub libvirt: Vec<LibvirtConfig>,
|
||||||
|
/// Path to the machine ID map JSON file
|
||||||
|
#[serde(default = "default_machine_ids")]
|
||||||
|
pub machine_ids: PathBuf,
|
||||||
|
/// CA configuration
|
||||||
|
#[serde(default)]
|
||||||
|
pub ca: CaConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Configuration {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
libvirt: vec![],
|
||||||
|
machine_ids: default_machine_ids(),
|
||||||
|
ca: Default::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_config_path(basename: &str) -> PathBuf {
|
||||||
|
dirs::config_dir().map_or(PathBuf::from(basename), |mut p| {
|
||||||
|
p.push(env!("CARGO_PKG_NAME"));
|
||||||
|
p.push(basename);
|
||||||
|
p
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_machine_ids() -> PathBuf {
|
||||||
|
default_config_path("machine_ids.json")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_host_ca_key() -> PathBuf {
|
||||||
|
default_config_path("host-ca.key")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_host_cert_duration() -> u64 {
|
||||||
|
86400 * 30
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load configuration from a TOML file
|
||||||
|
///
|
||||||
|
/// If `path` is provided, the configuration will be loaded from the
|
||||||
|
/// TOML file at that location. If `path` is `None`, the path will be
|
||||||
|
/// inferred from the XDG Configuration directory (i.e.
|
||||||
|
/// `${XDG_CONFIG_HOME}/sshca/config.toml`).
|
||||||
|
///
|
||||||
|
/// If the configuration file does not exist, the default values will be
|
||||||
|
/// used. If any error is encountered while reading or parsing the
|
||||||
|
/// file, a [`ConfigError`] will be returned.
|
||||||
|
pub fn load_config<P>(path: Option<P>) -> Result<Configuration, ConfigError>
|
||||||
|
where
|
||||||
|
P: AsRef<Path>,
|
||||||
|
{
|
||||||
|
let path = match path {
|
||||||
|
Some(p) => PathBuf::from(p.as_ref()),
|
||||||
|
None => default_config_path("config.toml"),
|
||||||
|
};
|
||||||
|
debug!("Loading configuration from {}", path.display());
|
||||||
|
match std::fs::read_to_string(path) {
|
||||||
|
Ok(s) => Ok(toml::from_str(&s)?),
|
||||||
|
Err(ref e) if e.kind() == ErrorKind::NotFound => {
|
||||||
|
Ok(Default::default())
|
||||||
|
}
|
||||||
|
Err(e) => Err(e.into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test {
|
||||||
|
use serial_test::serial;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[serial]
|
||||||
|
fn test_default_config() {
|
||||||
|
std::env::remove_var("XDG_CONFIG_HOME");
|
||||||
|
std::env::set_var("HOME", "/home/user");
|
||||||
|
let config = Configuration::default();
|
||||||
|
assert_eq!(
|
||||||
|
config.machine_ids,
|
||||||
|
PathBuf::from("/home/user/.config/sshca/machine_ids.json"),
|
||||||
|
);
|
||||||
|
assert_eq!(config.libvirt.len(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[serial]
|
||||||
|
fn test_default_config_path() {
|
||||||
|
std::env::remove_var("XDG_CONFIG_HOME");
|
||||||
|
std::env::set_var("HOME", "/home/user");
|
||||||
|
let path = default_config_path("config.toml");
|
||||||
|
assert_eq!(
|
||||||
|
path,
|
||||||
|
PathBuf::from("/home/user/.config/sshca/config.toml"),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[serial]
|
||||||
|
fn test_default_config_path_env() {
|
||||||
|
std::env::set_var("XDG_CONFIG_HOME", "/etc");
|
||||||
|
let path = default_config_path("config.toml");
|
||||||
|
assert_eq!(path, PathBuf::from("/etc/sshca/config.toml"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_config_toml() {
|
||||||
|
let config_toml = r#"
|
||||||
|
[[libvirt]]
|
||||||
|
uri = "qemu+ssh://vmhost0.example.org/system"
|
||||||
|
|
||||||
|
[[libvirt]]
|
||||||
|
uri = "qemu+ssh://vmhost1.example.org/system"
|
||||||
|
"#;
|
||||||
|
let config: Configuration = toml::from_str(config_toml).unwrap();
|
||||||
|
assert_eq!(config.libvirt.len(), 2);
|
||||||
|
assert_eq!(
|
||||||
|
config.libvirt[0].uri,
|
||||||
|
"qemu+ssh://vmhost0.example.org/system"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
config.libvirt[1].uri,
|
||||||
|
"qemu+ssh://vmhost1.example.org/system"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,7 @@
|
||||||
|
mod auth;
|
||||||
|
mod ca;
|
||||||
|
pub mod config;
|
||||||
|
#[cfg(feature = "libvirt")]
|
||||||
|
mod libvirt;
|
||||||
|
mod machine_id;
|
||||||
|
pub mod server;
|
|
@ -0,0 +1,53 @@
|
||||||
|
//! libvirt VM info
|
||||||
|
use tracing::error;
|
||||||
|
use virt::connect::Connect;
|
||||||
|
use virt::domain::Domain;
|
||||||
|
|
||||||
|
use crate::config::LibvirtConfig;
|
||||||
|
|
||||||
|
/// libvirt connection wrapper with auto-disconnect
|
||||||
|
///
|
||||||
|
/// This structure wraps a virConnect object and automatically calls
|
||||||
|
/// virConnectClose when it is dropped.
|
||||||
|
struct Connection {
|
||||||
|
conn: Connect,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for Connection {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if let Err(e) = self.conn.close() {
|
||||||
|
error!("Failed to close libvirt connection: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the UUID of a libvirt VM
|
||||||
|
///
|
||||||
|
/// This function connects to the libvirt daemon specified in the provided
|
||||||
|
/// [`LibvirtConfig`] structure and attempts to get the UUID of the virtual
|
||||||
|
/// machine with the given name. If the connection cannot be established, or
|
||||||
|
/// no machine is found with that name, an error is returned.
|
||||||
|
pub fn get_machine_id(
|
||||||
|
name: &str,
|
||||||
|
config: &LibvirtConfig,
|
||||||
|
) -> Result<String, virt::error::Error> {
|
||||||
|
let conn = Connection {
|
||||||
|
conn: Connect::open_read_only(&config.uri)?,
|
||||||
|
};
|
||||||
|
let dom = Domain::lookup_by_name(&conn.conn, name)?;
|
||||||
|
dom.get_uuid_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_get_machine_id() {
|
||||||
|
let config = LibvirtConfig {
|
||||||
|
uri: "test:///default".into(),
|
||||||
|
};
|
||||||
|
let machine_id = get_machine_id("test", &config).unwrap();
|
||||||
|
assert_eq!(&machine_id, "6695eb01-f6a4-8304-79aa-97f2502e193f");
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,108 @@
|
||||||
|
//! Look up a known machine ID
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use serde_json::Value;
|
||||||
|
use tokio::task;
|
||||||
|
use tracing::{debug, error};
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::config::Configuration;
|
||||||
|
use crate::libvirt;
|
||||||
|
|
||||||
|
/// Look up the machine ID for a host
|
||||||
|
///
|
||||||
|
/// This function will return the machine ID for a known host, either from the
|
||||||
|
/// static map stored in a JSON file or from a libvirt VM host. If no location
|
||||||
|
/// has a record for the machine, `None` is returned.
|
||||||
|
pub async fn get_machine_id(
|
||||||
|
hostname: &str,
|
||||||
|
config: Arc<Configuration>,
|
||||||
|
) -> Option<Uuid> {
|
||||||
|
if let Some(v) = from_map(hostname, config.clone()).await {
|
||||||
|
return parse_uuid(&v);
|
||||||
|
}
|
||||||
|
#[cfg(feature = "libvirt")]
|
||||||
|
if let Some(v) = from_libvirt(hostname, config.clone()).await {
|
||||||
|
return parse_uuid(&v);
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Look up a machine ID in the JSON map file
|
||||||
|
///
|
||||||
|
/// This function reads the machine ID map stored in the JSON file. If the
|
||||||
|
/// file could not be opened, is not a valid JSON object, or does not contain
|
||||||
|
/// an entry for the specified hostname, `None` is returned.
|
||||||
|
async fn from_map(
|
||||||
|
hostname: &str,
|
||||||
|
config: Arc<Configuration>,
|
||||||
|
) -> Option<String> {
|
||||||
|
let data = match tokio::fs::read_to_string(&config.machine_ids).await {
|
||||||
|
Ok(d) => d,
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to read machine ID map file: {}", e);
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let res =
|
||||||
|
task::spawn_blocking(move || serde_json::from_str::<Value>(&data))
|
||||||
|
.await;
|
||||||
|
let map = match res {
|
||||||
|
Ok(Ok(r)) => r,
|
||||||
|
Ok(Err(e)) => {
|
||||||
|
error!("Error parsing machine ID map: {}", e);
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Unexpected error while parsing JSON: {}", e);
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Some(map.as_object()?.get(hostname)?.as_str()?.to_owned())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Look up a machine ID on the configured libvirt VM hosts
|
||||||
|
///
|
||||||
|
/// This function iterates over the configured VM hosts and checks each one for
|
||||||
|
/// a domain matching the specified hostname. If no VM host is available or
|
||||||
|
/// none of the configured hosts have a matching domain, `None` is returned.
|
||||||
|
#[cfg(feature = "libvirt")]
|
||||||
|
async fn from_libvirt(
|
||||||
|
hostname: &str,
|
||||||
|
config: Arc<Configuration>,
|
||||||
|
) -> Option<String> {
|
||||||
|
let hostname = Arc::new(hostname.split('.').next().unwrap().to_string());
|
||||||
|
let res = task::spawn_blocking(move || {
|
||||||
|
for libvirt in &config.libvirt {
|
||||||
|
debug!("Checking {} for {}", libvirt.uri, hostname);
|
||||||
|
match libvirt::get_machine_id(&hostname, libvirt) {
|
||||||
|
Ok(v) => return Some(v),
|
||||||
|
Err(e) => {
|
||||||
|
debug!("libvirt error: {}", e);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
None
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
match res {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(e) => {
|
||||||
|
error!("Unexpected error while querying libvirt: {}", e);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse a UUID from a string
|
||||||
|
///
|
||||||
|
/// Returns `None` if the string does not contain a valid UUID.
|
||||||
|
fn parse_uuid(value: &str) -> Option<Uuid> {
|
||||||
|
Uuid::parse_str(value).map_or_else(
|
||||||
|
|e| {
|
||||||
|
error!("Invalid UUID: {}", e);
|
||||||
|
None
|
||||||
|
},
|
||||||
|
Some,
|
||||||
|
)
|
||||||
|
}
|
|
@ -0,0 +1,46 @@
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
use argh::FromArgs;
|
||||||
|
use tracing_subscriber::EnvFilter;
|
||||||
|
|
||||||
|
use sshca::config;
|
||||||
|
use sshca::server;
|
||||||
|
|
||||||
|
/// Online SSH CA Service
|
||||||
|
#[derive(FromArgs)]
|
||||||
|
struct Args {
|
||||||
|
/// path to configuration file
|
||||||
|
#[argh(option, short = 'c')]
|
||||||
|
config_file: Option<PathBuf>,
|
||||||
|
|
||||||
|
/// listen address
|
||||||
|
#[argh(option, short = 'l')]
|
||||||
|
listen_address: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main(flavor = "current_thread")]
|
||||||
|
async fn main() {
|
||||||
|
tracing_subscriber::fmt()
|
||||||
|
.with_env_filter(EnvFilter::from_default_env())
|
||||||
|
.with_writer(std::io::stderr)
|
||||||
|
.init();
|
||||||
|
|
||||||
|
// stop libvirt printing to stderr
|
||||||
|
virt::error::clear_error_callback();
|
||||||
|
|
||||||
|
let args: Args = argh::from_env();
|
||||||
|
|
||||||
|
let listen_address = args
|
||||||
|
.listen_address
|
||||||
|
.as_ref()
|
||||||
|
.unwrap_or(&"[::]:8087".into())
|
||||||
|
.parse()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let config = config::load_config(args.config_file.as_deref()).unwrap();
|
||||||
|
let app = server::make_app(config);
|
||||||
|
axum::Server::bind(&listen_address)
|
||||||
|
.serve(app.into_make_service())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
|
@ -0,0 +1,147 @@
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use axum::extract::multipart::{Multipart, MultipartError};
|
||||||
|
use axum::extract::State;
|
||||||
|
use axum::http::StatusCode;
|
||||||
|
use axum::response::{IntoResponse, Response};
|
||||||
|
use serde::Serialize;
|
||||||
|
use ssh_key::Algorithm;
|
||||||
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
|
use crate::auth::Claims;
|
||||||
|
use crate::ca;
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
pub struct SignKeyResponse {
|
||||||
|
success: bool,
|
||||||
|
errors: Vec<String>,
|
||||||
|
certificates: HashMap<String, String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum SignKeyError {
|
||||||
|
Multipart(MultipartError),
|
||||||
|
Cert(ca::CertError),
|
||||||
|
LoadPrivateKey(ca::LoadKeyError),
|
||||||
|
ParsePublicKey(ca::LoadKeyError),
|
||||||
|
UnsupportedAlgorithm(String),
|
||||||
|
NoKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<MultipartError> for SignKeyError {
|
||||||
|
fn from(e: MultipartError) -> Self {
|
||||||
|
Self::Multipart(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ca::CertError> for SignKeyError {
|
||||||
|
fn from(e: ca::CertError) -> Self {
|
||||||
|
Self::Cert(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse for SignKeyError {
|
||||||
|
fn into_response(self) -> Response {
|
||||||
|
match self {
|
||||||
|
Self::Multipart(e) => {
|
||||||
|
debug!("Error reading request: {}", e);
|
||||||
|
let body = e.to_string();
|
||||||
|
(StatusCode::BAD_REQUEST, body).into_response()
|
||||||
|
}
|
||||||
|
Self::Cert(e) => {
|
||||||
|
error!("Failed to sign certificate: {}", e);
|
||||||
|
let body = "Service Unavailable";
|
||||||
|
(StatusCode::SERVICE_UNAVAILABLE, body).into_response()
|
||||||
|
}
|
||||||
|
Self::LoadPrivateKey(e) => {
|
||||||
|
error!("Error loading CA private key: {}", e);
|
||||||
|
let body = "Service Unavailable";
|
||||||
|
(StatusCode::SERVICE_UNAVAILABLE, body).into_response()
|
||||||
|
}
|
||||||
|
Self::ParsePublicKey(e) => {
|
||||||
|
error!("Error parsing public keykey: {}", e);
|
||||||
|
let body = e.to_string();
|
||||||
|
(StatusCode::BAD_REQUEST, body).into_response()
|
||||||
|
}
|
||||||
|
Self::UnsupportedAlgorithm(a) => {
|
||||||
|
debug!("Requested certificate for unsupported key algorithm \"{}\"", a);
|
||||||
|
let body = format!("Unsupported key algorithm: {}", a);
|
||||||
|
(StatusCode::BAD_REQUEST, body).into_response()
|
||||||
|
}
|
||||||
|
Self::NoKey => {
|
||||||
|
debug!("No SSH public key provided in request");
|
||||||
|
(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
"No SSH public key provided in request",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
struct SignKeyRequest {
|
||||||
|
hostname: String,
|
||||||
|
pubkey: Vec<u8>,
|
||||||
|
aliases: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn sign_host_cert(
|
||||||
|
claims: Claims,
|
||||||
|
State(ctx): State<super::State>,
|
||||||
|
mut form: Multipart,
|
||||||
|
) -> Result<String, SignKeyError> {
|
||||||
|
let hostname = claims.sub;
|
||||||
|
let mut body = SignKeyRequest::default();
|
||||||
|
|
||||||
|
while let Some(field) = form.next_field().await? {
|
||||||
|
match field.name() {
|
||||||
|
Some("pubkey") => {
|
||||||
|
body.pubkey = field.bytes().await?.into();
|
||||||
|
}
|
||||||
|
Some("alias") => body.aliases.push(field.text().await?),
|
||||||
|
Some("hostname") => body.hostname = field.text().await?,
|
||||||
|
Some(n) => {
|
||||||
|
warn!("Client request included unsupported field {:?}", n);
|
||||||
|
}
|
||||||
|
None => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if body.pubkey.is_empty() {
|
||||||
|
return Err(SignKeyError::NoKey);
|
||||||
|
}
|
||||||
|
let aliases: Vec<_> = body.aliases.iter().map(String::as_ref).collect();
|
||||||
|
|
||||||
|
let config = &ctx.config;
|
||||||
|
let duration = Duration::from_secs(config.ca.host.cert_duration);
|
||||||
|
let privkey = ca::load_private_key(&config.ca.host.private_key_file)
|
||||||
|
.await
|
||||||
|
.map_err(SignKeyError::LoadPrivateKey)?;
|
||||||
|
|
||||||
|
let pubkey = ca::parse_public_key(&body.pubkey)
|
||||||
|
.map_err(SignKeyError::ParsePublicKey)?;
|
||||||
|
match pubkey.algorithm() {
|
||||||
|
Algorithm::Ecdsa { .. } => (),
|
||||||
|
Algorithm::Ed25519 => (),
|
||||||
|
Algorithm::Rsa { .. } => (),
|
||||||
|
_ => {
|
||||||
|
return Err(SignKeyError::UnsupportedAlgorithm(
|
||||||
|
pubkey.algorithm().as_str().into(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
debug!(
|
||||||
|
"Signing {} key for {}",
|
||||||
|
pubkey.algorithm().as_str(),
|
||||||
|
hostname
|
||||||
|
);
|
||||||
|
let cert =
|
||||||
|
ca::sign_cert(&hostname, &pubkey, duration, &privkey, &aliases)?;
|
||||||
|
info!(
|
||||||
|
"Signed {} key for {}",
|
||||||
|
pubkey.algorithm().as_str(),
|
||||||
|
hostname
|
||||||
|
);
|
||||||
|
Ok(cert.to_openssh().map_err(ca::CertError::from)?)
|
||||||
|
}
|
|
@ -0,0 +1,110 @@
|
||||||
|
mod host;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use axum::async_trait;
|
||||||
|
use axum::extract::FromRequestParts;
|
||||||
|
use axum::headers::authorization::Bearer;
|
||||||
|
use axum::headers::{Authorization, Host};
|
||||||
|
use axum::http::request::Parts;
|
||||||
|
use axum::http::StatusCode;
|
||||||
|
use axum::response::{IntoResponse, Response};
|
||||||
|
use axum::routing::{get, post};
|
||||||
|
use axum::{RequestPartsExt, Router, TypedHeader};
|
||||||
|
use tracing::debug;
|
||||||
|
|
||||||
|
use crate::auth::{self, Claims};
|
||||||
|
use crate::config::Configuration;
|
||||||
|
use crate::machine_id;
|
||||||
|
|
||||||
|
struct Context {
|
||||||
|
config: Arc<Configuration>,
|
||||||
|
}
|
||||||
|
|
||||||
|
type State = Arc<Context>;
|
||||||
|
|
||||||
|
pub struct AuthError;
|
||||||
|
|
||||||
|
impl IntoResponse for AuthError {
|
||||||
|
fn into_response(self) -> Response {
|
||||||
|
(StatusCode::UNAUTHORIZED, "Unauthorized").into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl FromRequestParts<Arc<Context>> for Claims {
|
||||||
|
type Rejection = AuthError;
|
||||||
|
|
||||||
|
async fn from_request_parts(
|
||||||
|
parts: &mut Parts,
|
||||||
|
ctx: &State,
|
||||||
|
) -> Result<Self, Self::Rejection> {
|
||||||
|
let TypedHeader(Authorization(bearer)) = parts
|
||||||
|
.extract::<TypedHeader<Authorization<Bearer>>>()
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
debug!("Failed to extract token from HTTP request: {}", e);
|
||||||
|
AuthError
|
||||||
|
})?;
|
||||||
|
let host = parts.extract::<TypedHeader<Host>>().await.map_or_else(
|
||||||
|
|_| "localhost".to_owned(),
|
||||||
|
|v| v.hostname().to_owned(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let hostname =
|
||||||
|
auth::get_token_subject(bearer.token()).map_err(|e| {
|
||||||
|
debug!("Could not get token subject: {}", e);
|
||||||
|
AuthError
|
||||||
|
})?;
|
||||||
|
let machine_id =
|
||||||
|
machine_id::get_machine_id(&hostname, ctx.config.clone())
|
||||||
|
.await
|
||||||
|
.ok_or_else(|| {
|
||||||
|
debug!("No machine ID found for host {}", hostname);
|
||||||
|
AuthError
|
||||||
|
})?;
|
||||||
|
let claims = auth::validate_token(
|
||||||
|
bearer.token(),
|
||||||
|
&hostname,
|
||||||
|
&machine_id,
|
||||||
|
&host,
|
||||||
|
)
|
||||||
|
.map_err(|e| {
|
||||||
|
debug!("Invalid auth token: {}", e);
|
||||||
|
AuthError
|
||||||
|
})?;
|
||||||
|
debug!("Successfully authenticated request from host {}", hostname);
|
||||||
|
Ok(claims)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn make_app(config: Configuration) -> Router {
|
||||||
|
let ctx = Arc::new(Context {
|
||||||
|
config: config.into(),
|
||||||
|
});
|
||||||
|
Router::new()
|
||||||
|
.route("/", get(|| async { "UP" }))
|
||||||
|
.route("/host/sign", post(host::sign_host_cert))
|
||||||
|
.with_state(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test {
|
||||||
|
use axum::body::Body;
|
||||||
|
use axum::http::Request;
|
||||||
|
use tower::ServiceExt;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_up() {
|
||||||
|
let app = make_app(Configuration::default());
|
||||||
|
let response = app
|
||||||
|
.oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(response.status(), StatusCode::OK);
|
||||||
|
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
|
||||||
|
assert_eq!(&body[..], b"UP");
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,2 @@
|
||||||
|
pub mod setup;
|
||||||
|
pub mod token;
|
|
@ -0,0 +1,71 @@
|
||||||
|
use std::error::Error;
|
||||||
|
use std::io::prelude::*;
|
||||||
|
use std::path::Path;
|
||||||
|
use std::sync::Once;
|
||||||
|
|
||||||
|
use rand_core::OsRng;
|
||||||
|
use ssh_key::{Algorithm, Fingerprint, PrivateKey, PublicKey};
|
||||||
|
use tempfile::NamedTempFile;
|
||||||
|
use tracing_subscriber::EnvFilter;
|
||||||
|
|
||||||
|
use sshca::config::Configuration;
|
||||||
|
|
||||||
|
static INIT: Once = Once::new();
|
||||||
|
|
||||||
|
fn gen_machine_ids() -> Result<NamedTempFile, Box<dyn Error>> {
|
||||||
|
let f = NamedTempFile::new()?;
|
||||||
|
let map = serde_json::json!({
|
||||||
|
"test.example.org": "b75e9126-d73a-4ae0-9a0d-63cb3552e6cd",
|
||||||
|
});
|
||||||
|
serde_json::to_writer(&f, &map)?;
|
||||||
|
Ok(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn gen_config(machine_ids: &Path, host_key: &Path) -> Configuration {
|
||||||
|
let mut config = Configuration {
|
||||||
|
machine_ids: machine_ids.to_str().unwrap().into(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
config.ca.host.private_key_file = host_key.to_str().unwrap().into();
|
||||||
|
config
|
||||||
|
}
|
||||||
|
|
||||||
|
fn gen_ca_key() -> Result<(NamedTempFile, PublicKey), Box<dyn Error>> {
|
||||||
|
let key = PrivateKey::random(&mut OsRng, Algorithm::Ed25519)?;
|
||||||
|
let mut f = NamedTempFile::new()?;
|
||||||
|
f.write_all(key.to_openssh(Default::default())?.as_bytes())?;
|
||||||
|
Ok((f, key.public_key().clone()))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn setup() -> Result<(TestContext, Configuration), Box<dyn Error>> {
|
||||||
|
INIT.call_once(|| {
|
||||||
|
tracing_subscriber::fmt::fmt()
|
||||||
|
.with_env_filter(EnvFilter::from("sshca=trace"))
|
||||||
|
.with_test_writer()
|
||||||
|
.init();
|
||||||
|
});
|
||||||
|
|
||||||
|
let machine_ids = gen_machine_ids()?;
|
||||||
|
let (host_key, host_key_pub) = gen_ca_key()?;
|
||||||
|
let config = gen_config(machine_ids.path(), host_key.path());
|
||||||
|
|
||||||
|
let ctx = TestContext {
|
||||||
|
machine_ids,
|
||||||
|
host_key,
|
||||||
|
host_key_pub,
|
||||||
|
};
|
||||||
|
Ok((ctx, config))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub struct TestContext {
|
||||||
|
machine_ids: NamedTempFile,
|
||||||
|
host_key: NamedTempFile,
|
||||||
|
host_key_pub: PublicKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TestContext {
|
||||||
|
pub fn host_ca_fingerprint(&self) -> Fingerprint {
|
||||||
|
self.host_key_pub.fingerprint(Default::default())
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,41 @@
|
||||||
|
use std::time;
|
||||||
|
|
||||||
|
use argon2::Argon2;
|
||||||
|
use jsonwebtoken::{encode, EncodingKey};
|
||||||
|
use serde::Serialize;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct TestClaims {
|
||||||
|
sub: String,
|
||||||
|
iss: String,
|
||||||
|
aud: String,
|
||||||
|
iat: u64,
|
||||||
|
nbf: u64,
|
||||||
|
exp: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn make_token(hostname: &str, machine_id: Uuid) -> String {
|
||||||
|
let now = time::SystemTime::now()
|
||||||
|
.duration_since(time::UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_secs();
|
||||||
|
let claims = TestClaims {
|
||||||
|
sub: hostname.into(),
|
||||||
|
iss: hostname.into(),
|
||||||
|
aud: "sshca.example.org".into(),
|
||||||
|
nbf: now - 60,
|
||||||
|
iat: now,
|
||||||
|
exp: now + 60,
|
||||||
|
};
|
||||||
|
let mut secret = [0u8; 32];
|
||||||
|
Argon2::default()
|
||||||
|
.hash_password_into(
|
||||||
|
machine_id.as_bytes(),
|
||||||
|
hostname.as_bytes(),
|
||||||
|
&mut secret,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let key = EncodingKey::from_secret(&secret);
|
||||||
|
encode(&Default::default(), &claims, &key).unwrap()
|
||||||
|
}
|
|
@ -0,0 +1,209 @@
|
||||||
|
mod common;
|
||||||
|
|
||||||
|
use axum::body::Body;
|
||||||
|
use axum::http::{Request, StatusCode};
|
||||||
|
use form_data_builder::FormData;
|
||||||
|
use ssh_key::{Algorithm, Certificate};
|
||||||
|
use tower::ServiceExt;
|
||||||
|
use uuid::uuid;
|
||||||
|
|
||||||
|
use sshca::server::make_app;
|
||||||
|
|
||||||
|
use common::setup;
|
||||||
|
use common::token;
|
||||||
|
|
||||||
|
const ED25519_KEY: &str = concat!(
|
||||||
|
"ssh-ed25519 ",
|
||||||
|
"AAAAC3NzaC1lZDI1NTE5AAAAIAsFEmrNIoRHHUayEO0NdAIgtMvci/wME07h+A5XSNJy",
|
||||||
|
);
|
||||||
|
const DSA_KEY: &str = concat!(
|
||||||
|
"ssh-dss ",
|
||||||
|
"AAAAB3NzaC1kc3MAAACBALNAS+fZjaWt4q+MAgjf6HREFoYjgoSVJUUCtNmRGhND85msVtla",
|
||||||
|
"kll2gLzL6n6TWyiToARlThoTFu1ZDoGYauDL7iDXrGB6VWJEOQZ3TEMHLFYPziW02AbjR9GI",
|
||||||
|
"ptsF42D0bTTvvaIaBIhOTjWAUjuFIhAKhPkcj+udIcyH8CG1AAAAFQCpbXQSlxOvd4J92j2C",
|
||||||
|
"rWDYVGoK8wAAAIAfHiV6/glGZrDRztJmw1hfwbmiNPxaoSGkB+Necfkj0fZrlyLj8sLJIbGQ",
|
||||||
|
"w0dJMATZdRHw3Ql4R5IOu7sBfX1KQW++onT4ads/Xtl6vwfsjO2e/a6Y1ib9JCIOGJxNAAUC",
|
||||||
|
"JU0Fm0TSv2Nn6UTICAarp1eKALimqkvy1+ygBWjprgAAAIEAic5EpZH9wpgzvl9kPW531yrz",
|
||||||
|
"IOlCcXsJFPqQxUThrB2o1g3Rjpscd9kCw5UlPu6GGLk4aSN3UxeIKymTuKiEi7tvP1Tj/Bv5",
|
||||||
|
"tEc4rhfmrBAfAST09oRFDsELufsOAlTrJ0uk2LhtN14H1RBv9qPR5PQKTEYslyvXG1f8itNQ",
|
||||||
|
"YnQ="
|
||||||
|
);
|
||||||
|
|
||||||
|
fn make_test_request_body(key: &[u8], name: &str) -> (Body, String) {
|
||||||
|
let mut form = FormData::new(Vec::new());
|
||||||
|
form.write_file(
|
||||||
|
"pubkey",
|
||||||
|
key,
|
||||||
|
Some(name.as_ref()),
|
||||||
|
"application/octet-stream",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let content_type = form.content_type_header();
|
||||||
|
let body = Body::from(form.finish().unwrap());
|
||||||
|
(body, content_type)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_test_request(body: Body, content_type: &str) -> Request<Body> {
|
||||||
|
let hostname = "test.example.org";
|
||||||
|
let machine_id = uuid!("b75e9126-d73a-4ae0-9a0d-63cb3552e6cd");
|
||||||
|
let token = token::make_token(hostname, machine_id);
|
||||||
|
Request::builder()
|
||||||
|
.uri("/host/sign")
|
||||||
|
.method("POST")
|
||||||
|
.header("Authorization", format!("Bearer {}", token))
|
||||||
|
.header("Host", "sshca.example.org")
|
||||||
|
.header("Content-Type", content_type)
|
||||||
|
.body(body)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_sign() {
|
||||||
|
let (ctx, config) = setup::setup().await.unwrap();
|
||||||
|
|
||||||
|
let app = make_app(config);
|
||||||
|
let (body, content_type) = make_test_request_body(
|
||||||
|
ED25519_KEY.as_bytes(),
|
||||||
|
"ssh_host_ed25519_key.pub",
|
||||||
|
);
|
||||||
|
let req = make_test_request(body, &content_type);
|
||||||
|
let res = app.oneshot(req).await.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(res.status(), StatusCode::OK);
|
||||||
|
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
|
||||||
|
let cert = Certificate::from_openssh(std::str::from_utf8(&body).unwrap())
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(cert.algorithm(), Algorithm::Ed25519);
|
||||||
|
cert.validate(&[ctx.host_ca_fingerprint()]).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_sign_invalid() {
|
||||||
|
let (_ctx, config) = setup::setup().await.unwrap();
|
||||||
|
|
||||||
|
let (body, content_type) = make_test_request_body(
|
||||||
|
"this is not a valid openssh key".as_bytes(),
|
||||||
|
"ssh_host_ecdsa_key.pub",
|
||||||
|
);
|
||||||
|
let app = make_app(config);
|
||||||
|
let req = make_test_request(body, &content_type);
|
||||||
|
let res = app.oneshot(req).await.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||||
|
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
body,
|
||||||
|
concat!(
|
||||||
|
"Could not parse SSH key: ",
|
||||||
|
"Base64 encoding error: invalid Base64 encoding",
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_sign_nokey() {
|
||||||
|
let (_ctx, config) = setup::setup().await.unwrap();
|
||||||
|
|
||||||
|
let mut form = FormData::new(Vec::new());
|
||||||
|
let content_type = form.content_type_header();
|
||||||
|
let body = Body::from(form.finish().unwrap());
|
||||||
|
let app = make_app(config);
|
||||||
|
let req = make_test_request(body, &content_type);
|
||||||
|
let res = app.oneshot(req).await.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||||
|
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
|
||||||
|
assert_eq!(body, "No SSH public key provided in request",);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_sign_mangled() {
|
||||||
|
let (_ctx, config) = setup::setup().await.unwrap();
|
||||||
|
|
||||||
|
let app = make_app(config);
|
||||||
|
let mut form = FormData::new(Vec::new());
|
||||||
|
form.write_file(
|
||||||
|
"pubkey",
|
||||||
|
ED25519_KEY.as_bytes(),
|
||||||
|
Some("ssh_host_ed25519_key.pub".as_ref()),
|
||||||
|
"application/octet-stream",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let content_type = form.content_type_header();
|
||||||
|
let mut form_bytes = form.finish().unwrap();
|
||||||
|
form_bytes.truncate(19);
|
||||||
|
let body = Body::from(form_bytes);
|
||||||
|
let req = make_test_request(body, &content_type);
|
||||||
|
let res = app.oneshot(req).await.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||||
|
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
|
||||||
|
assert_eq!(body, "Error parsing `multipart/form-data` request",);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_sign_bad_request() {
|
||||||
|
let (_ctx, config) = setup::setup().await.unwrap();
|
||||||
|
|
||||||
|
let app = make_app(config);
|
||||||
|
let content_type = "text/plain";
|
||||||
|
let body = Body::from("test");
|
||||||
|
let req = make_test_request(body, content_type);
|
||||||
|
let res = app.oneshot(req).await.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||||
|
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
|
||||||
|
assert_eq!(body, "Invalid `boundary` for `multipart/form-data` request",);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_sign_dsa() {
|
||||||
|
let (_ctx, config) = setup::setup().await.unwrap();
|
||||||
|
|
||||||
|
let app = make_app(config);
|
||||||
|
let (body, content_type) =
|
||||||
|
make_test_request_body(DSA_KEY.as_bytes(), "ssh_host_dsa_key.pub");
|
||||||
|
let req = make_test_request(body, &content_type);
|
||||||
|
let res = app.oneshot(req).await.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||||
|
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
|
||||||
|
assert_eq!(body, "Unsupported key algorithm: ssh-dss");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_sign_failure() {
|
||||||
|
let (_ctx, mut config) = setup::setup().await.unwrap();
|
||||||
|
config.ca.host.private_key_file = "bogus".into();
|
||||||
|
|
||||||
|
let app = make_app(config);
|
||||||
|
let (body, content_type) = make_test_request_body(
|
||||||
|
ED25519_KEY.as_bytes(),
|
||||||
|
"ssh_host_ed25519_key.pub",
|
||||||
|
);
|
||||||
|
let req = make_test_request(body, &content_type);
|
||||||
|
let res = app.oneshot(req).await.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(res.status(), StatusCode::SERVICE_UNAVAILABLE);
|
||||||
|
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
|
||||||
|
assert_eq!(body, "Service Unavailable");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_sign_unauthorized() {
|
||||||
|
// Deliberately drop the TestContext so the machine ID file gets deleted,
|
||||||
|
// which will cause authentication to fail.
|
||||||
|
let (_, config) = setup::setup().await.unwrap();
|
||||||
|
|
||||||
|
let app = make_app(config);
|
||||||
|
let (body, content_type) = make_test_request_body(
|
||||||
|
ED25519_KEY.as_bytes(),
|
||||||
|
"ssh_host_ed25519_key.pub",
|
||||||
|
);
|
||||||
|
let req = make_test_request(body, &content_type);
|
||||||
|
let res = app.oneshot(req).await.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
|
||||||
|
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
|
||||||
|
assert_eq!(body, "Unauthorized");
|
||||||
|
}
|
Loading…
Reference in New Issue