Initial commit
commit
cadc977700
|
@ -0,0 +1,10 @@
|
|||
root = true
|
||||
|
||||
[*]
|
||||
charset = utf-8
|
||||
end_of_line = lf
|
||||
insert_final_newline = true
|
||||
trim_trailing_whitespace = true
|
||||
|
||||
[**.rs]
|
||||
max_line_length = 79
|
|
@ -0,0 +1 @@
|
|||
Cargo.lock -diff
|
|
@ -0,0 +1,4 @@
|
|||
/target
|
||||
/config.toml
|
||||
/machine_ids.json
|
||||
/host_ca_key*
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,33 @@
|
|||
[package]
|
||||
name = "sshca"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
argh = "0.1.12"
|
||||
argon2 = { version = "0.5.2", default-features = false, features = ["alloc"] }
|
||||
axum = { version = "0.6.20", features = ["multipart", "headers", "json"] }
|
||||
dirs = "5.0.1"
|
||||
jsonwebtoken = { version = "8.3.0", default-features = false }
|
||||
rand_core = { version = "0.6.4", features = ["getrandom"] }
|
||||
serde = { version = "1.0.190", features = ["derive"] }
|
||||
serde_json = "1.0.108"
|
||||
ssh-key = { version = "0.6.2", features = ["serde", "ed25519", "getrandom"] }
|
||||
tokio = { version = "1.33.0", features = ["rt", "macros", "net", "signal", "fs", "io-util"] }
|
||||
toml = "0.8.6"
|
||||
tracing = { version = "0.1.40", features = ["log"] }
|
||||
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
|
||||
uuid = "1.5.0"
|
||||
virt = { version = "0.3.1", optional = true }
|
||||
|
||||
[features]
|
||||
default = ["libvirt"]
|
||||
libvirt = ["dep:virt"]
|
||||
|
||||
[dev-dependencies]
|
||||
form-data-builder = "1.0.1"
|
||||
hyper = "0.14.27"
|
||||
serial_test = "2.0.0"
|
||||
tempfile = "3.8.1"
|
||||
tower = { version = "0.4.13", features = ["util"] }
|
||||
|
|
@ -0,0 +1 @@
|
|||
max_width = 79
|
|
@ -0,0 +1,136 @@
|
|||
//! Authentication
|
||||
//!
|
||||
//! SSHCA authentication is handled using JWTs. Each request must include a
|
||||
//! proper token, with the target system's hostname as the subject and issuer,
|
||||
//! signed by that system's machine ID.
|
||||
//!
|
||||
//! To identify identify the subject of a token, use [`get_token_subject`].
|
||||
//! This function will return the hostname specified in the token; use this
|
||||
//! to look up the machine ID of the system. Then, use [`validate_token`] to
|
||||
//! validate the token.
|
||||
|
||||
use argon2::Argon2;
|
||||
use jsonwebtoken::errors::Result;
|
||||
use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation};
|
||||
use rand_core::{OsRng, RngCore};
|
||||
use serde::Deserialize;
|
||||
use tracing::error;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// JWT Token Claims
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Claims {
|
||||
/// Token subject (machine hostname)
|
||||
pub sub: String,
|
||||
}
|
||||
|
||||
/// Extract the subject from the token
|
||||
///
|
||||
/// This function takes a JWT string and returns the value of the `sub` claim.
|
||||
/// The signature of the token is not validated, but if the token has expired,
|
||||
/// an error is returned.
|
||||
pub fn get_token_subject(token: &str) -> Result<String> {
|
||||
let mut v = Validation::new(Algorithm::HS256);
|
||||
v.insecure_disable_signature_validation();
|
||||
v.set_required_spec_claims(&["sub"]);
|
||||
let k = DecodingKey::from_secret(b"");
|
||||
let data: TokenData<Claims> = decode(token, &k, &v)?;
|
||||
Ok(data.claims.sub)
|
||||
}
|
||||
|
||||
/// Validate a token
|
||||
///
|
||||
/// This function validates a JWT for the specified hostname and its
|
||||
/// corresponding machine ID. The token is valid if it is issued for the
|
||||
/// specified hostname by the same, has an audience matching the value of the
|
||||
/// `service` argument, and is within its validity period (not before/expires).
|
||||
/// The token must be signed with HMAC-SHA256 using the host's machine ID as
|
||||
/// the secret key.
|
||||
pub fn validate_token(
|
||||
token: &str,
|
||||
hostname: &str,
|
||||
machine_id: &Uuid,
|
||||
service: &str,
|
||||
) -> Result<Claims> {
|
||||
let mut v = Validation::new(Algorithm::HS256);
|
||||
v.validate_nbf = true;
|
||||
v.set_issuer(&[hostname]);
|
||||
v.set_audience(&[service]);
|
||||
let mut secret = [0u8; 32];
|
||||
if let Err(e) = Argon2::default().hash_password_into(
|
||||
machine_id.as_bytes(),
|
||||
hostname.as_bytes(),
|
||||
&mut secret,
|
||||
) {
|
||||
error!("Could not derive token secret: {}", e);
|
||||
OsRng.fill_bytes(&mut secret);
|
||||
}
|
||||
let k = DecodingKey::from_secret(&secret);
|
||||
let data: TokenData<Claims> = decode(token, &k, &v)?;
|
||||
Ok(data.claims)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod test {
|
||||
use std::time;
|
||||
|
||||
use jsonwebtoken::{encode, EncodingKey};
|
||||
use serde::Serialize;
|
||||
use uuid::uuid;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct TestClaims {
|
||||
sub: String,
|
||||
iss: String,
|
||||
aud: String,
|
||||
iat: u64,
|
||||
nbf: u64,
|
||||
exp: u64,
|
||||
}
|
||||
|
||||
pub(crate) fn make_token(hostname: &str, machine_id: Uuid) -> String {
|
||||
let now = time::SystemTime::now()
|
||||
.duration_since(time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
let claims = TestClaims {
|
||||
sub: hostname.into(),
|
||||
iss: hostname.into(),
|
||||
aud: "sshca.example.org".into(),
|
||||
nbf: now - 60,
|
||||
iat: now,
|
||||
exp: now + 60,
|
||||
};
|
||||
let mut secret = [0u8; 32];
|
||||
Argon2::default()
|
||||
.hash_password_into(
|
||||
machine_id.as_bytes(),
|
||||
hostname.as_bytes(),
|
||||
&mut secret,
|
||||
)
|
||||
.unwrap();
|
||||
let key = EncodingKey::from_secret(&secret);
|
||||
encode(&Default::default(), &claims, &key).unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_token_subject() {
|
||||
let hostname = "web0.example.org";
|
||||
let machine_id = uuid!("890ec6fe-e6c2-4524-bec0-d474af8aa506");
|
||||
let token = make_token(hostname, machine_id);
|
||||
let sub = get_token_subject(&token).unwrap();
|
||||
assert_eq!(sub, hostname);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_token() {
|
||||
let hostname = "file0.example.org";
|
||||
|
||||
let machine_id = uuid!("9afd42e5-4ac3-4530-90c4-191869063ef9");
|
||||
let token = make_token(hostname, machine_id);
|
||||
validate_token(&token, hostname, &machine_id, "sshca.example.org")
|
||||
.unwrap();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,171 @@
|
|||
//! SSH certificate authority
|
||||
use std::path::Path;
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use ssh_key::certificate::{Builder, CertType};
|
||||
use ssh_key::rand_core::OsRng;
|
||||
use ssh_key::{Certificate, PrivateKey, PublicKey};
|
||||
use tokio::fs::File;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tracing::debug;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum CertError {
|
||||
SystemTime(std::time::SystemTimeError),
|
||||
SshKey(ssh_key::Error),
|
||||
}
|
||||
|
||||
impl From<std::time::SystemTimeError> for CertError {
|
||||
fn from(e: std::time::SystemTimeError) -> Self {
|
||||
Self::SystemTime(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ssh_key::Error> for CertError {
|
||||
fn from(e: ssh_key::Error) -> Self {
|
||||
Self::SshKey(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for CertError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::SystemTime(e) => write!(f, "Invalid time: {}", e),
|
||||
Self::SshKey(e) => write!(f, "SSH key error: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for CertError {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match self {
|
||||
Self::SystemTime(e) => Some(e),
|
||||
Self::SshKey(e) => Some(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum LoadKeyError {
|
||||
Utf8(std::str::Utf8Error),
|
||||
Io(std::io::Error),
|
||||
SshKey(ssh_key::Error),
|
||||
}
|
||||
|
||||
impl From<std::str::Utf8Error> for LoadKeyError {
|
||||
fn from(e: std::str::Utf8Error) -> Self {
|
||||
Self::Utf8(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for LoadKeyError {
|
||||
fn from(e: std::io::Error) -> Self {
|
||||
Self::Io(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ssh_key::Error> for LoadKeyError {
|
||||
fn from(e: ssh_key::Error) -> Self {
|
||||
Self::SshKey(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for LoadKeyError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Utf8(e) => write!(f, "Invalid key: {}", e),
|
||||
Self::Io(e) => write!(f, "Could not read key file: {}", e),
|
||||
Self::SshKey(e) => write!(f, "Could not parse SSH key: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for LoadKeyError {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match self {
|
||||
Self::Utf8(e) => Some(e),
|
||||
Self::Io(e) => Some(e),
|
||||
Self::SshKey(e) => Some(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Load an SSH private key from a file
|
||||
pub async fn load_private_key<P>(path: P) -> Result<PrivateKey, LoadKeyError>
|
||||
where
|
||||
P: AsRef<Path>,
|
||||
{
|
||||
let mut data = Vec::new();
|
||||
debug!("Loading private key from {}", path.as_ref().display());
|
||||
let mut f = File::open(path).await?;
|
||||
f.read_to_end(&mut data).await?;
|
||||
parse_private_key(&data)
|
||||
}
|
||||
|
||||
/// Parse an SSH private key from a slice of bytes
|
||||
pub fn parse_private_key(data: &[u8]) -> Result<PrivateKey, LoadKeyError> {
|
||||
Ok(PrivateKey::from_openssh(std::str::from_utf8(data)?)?)
|
||||
}
|
||||
|
||||
/// Parse an SSH public key from a slice of bytes
|
||||
pub fn parse_public_key(data: &[u8]) -> Result<PublicKey, LoadKeyError> {
|
||||
Ok(PublicKey::from_openssh(std::str::from_utf8(data)?)?)
|
||||
}
|
||||
|
||||
/// Create a signed SSH certificate for a host public key
|
||||
///
|
||||
/// This function creates a signed certificate for an SSH host public
|
||||
/// key. The certificate will be valid for the specified hostname and
|
||||
/// any alias names provided.
|
||||
pub fn sign_cert(
|
||||
hostname: &str,
|
||||
pubkey: &PublicKey,
|
||||
duration: Duration,
|
||||
privkey: &PrivateKey,
|
||||
alias: &[&str],
|
||||
) -> Result<Certificate, CertError> {
|
||||
let now = SystemTime::now();
|
||||
let not_before = now.duration_since(UNIX_EPOCH)?.as_secs();
|
||||
let not_after = not_before + duration.as_secs();
|
||||
|
||||
let mut builder = Builder::new_with_random_nonce(
|
||||
&mut OsRng, pubkey, not_before, not_after,
|
||||
)?;
|
||||
builder.cert_type(CertType::Host)?;
|
||||
builder.valid_principal(hostname)?;
|
||||
for a in alias {
|
||||
builder.valid_principal(*a)?;
|
||||
}
|
||||
|
||||
Ok(builder.sign(privkey)?)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use ssh_key::Algorithm;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_sign_cert() {
|
||||
let ca_key =
|
||||
PrivateKey::random(&mut OsRng, Algorithm::Ed25519).unwrap();
|
||||
let host_key =
|
||||
PrivateKey::random(&mut OsRng, Algorithm::Ed25519).unwrap();
|
||||
let host_pub_key = host_key.public_key();
|
||||
let duration = Duration::from_secs(86400 * 30);
|
||||
let hostname = "cloud0.example.org";
|
||||
let cert = sign_cert(
|
||||
hostname,
|
||||
&host_pub_key,
|
||||
duration,
|
||||
&ca_key,
|
||||
&["nextcloud.example.org"],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let valid_principals = cert.valid_principals();
|
||||
assert_eq!(valid_principals[0], "cloud0.example.org");
|
||||
assert_eq!(valid_principals[1], "nextcloud.example.org");
|
||||
}
|
||||
}
|
|
@ -0,0 +1,211 @@
|
|||
//! Application configuration
|
||||
use std::io::ErrorKind;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use serde::Deserialize;
|
||||
use tracing::debug;
|
||||
|
||||
/// Error returned by [`load_config`]
|
||||
#[derive(Debug)]
|
||||
pub enum ConfigError {
|
||||
Io(std::io::Error),
|
||||
Toml(toml::de::Error),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ConfigError {
|
||||
fn from(e: std::io::Error) -> Self {
|
||||
Self::Io(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<toml::de::Error> for ConfigError {
|
||||
fn from(e: toml::de::Error) -> Self {
|
||||
Self::Toml(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ConfigError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Io(e) => write!(f, "Could not read config file: {}", e),
|
||||
Self::Toml(e) => write!(f, "Could not parse config: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for ConfigError {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match self {
|
||||
Self::Io(e) => Some(e),
|
||||
Self::Toml(e) => Some(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Host CA Configuration
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct HostCaConfig {
|
||||
/// Path to the Host CA private key file
|
||||
#[serde(default = "default_host_ca_key")]
|
||||
pub private_key_file: PathBuf,
|
||||
|
||||
/// Duration of issued host certificates
|
||||
#[serde(default = "default_host_cert_duration")]
|
||||
pub cert_duration: u64,
|
||||
}
|
||||
|
||||
impl Default for HostCaConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
private_key_file: default_host_ca_key(),
|
||||
cert_duration: default_host_cert_duration(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// CA configuration
|
||||
#[derive(Debug, Default, Deserialize)]
|
||||
pub struct CaConfig {
|
||||
/// Host CA configuration
|
||||
pub host: HostCaConfig,
|
||||
}
|
||||
|
||||
/// Defines a connection to a libvirt VM host
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct LibvirtConfig {
|
||||
/// libvirt Connection URI
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
/// Top-level configuration structure
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Configuration {
|
||||
/// List of libvirt connection options
|
||||
#[serde(default)]
|
||||
pub libvirt: Vec<LibvirtConfig>,
|
||||
/// Path to the machine ID map JSON file
|
||||
#[serde(default = "default_machine_ids")]
|
||||
pub machine_ids: PathBuf,
|
||||
/// CA configuration
|
||||
#[serde(default)]
|
||||
pub ca: CaConfig,
|
||||
}
|
||||
|
||||
impl Default for Configuration {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
libvirt: vec![],
|
||||
machine_ids: default_machine_ids(),
|
||||
ca: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn default_config_path(basename: &str) -> PathBuf {
|
||||
dirs::config_dir().map_or(PathBuf::from(basename), |mut p| {
|
||||
p.push(env!("CARGO_PKG_NAME"));
|
||||
p.push(basename);
|
||||
p
|
||||
})
|
||||
}
|
||||
|
||||
fn default_machine_ids() -> PathBuf {
|
||||
default_config_path("machine_ids.json")
|
||||
}
|
||||
|
||||
fn default_host_ca_key() -> PathBuf {
|
||||
default_config_path("host-ca.key")
|
||||
}
|
||||
|
||||
fn default_host_cert_duration() -> u64 {
|
||||
86400 * 30
|
||||
}
|
||||
|
||||
/// Load configuration from a TOML file
|
||||
///
|
||||
/// If `path` is provided, the configuration will be loaded from the
|
||||
/// TOML file at that location. If `path` is `None`, the path will be
|
||||
/// inferred from the XDG Configuration directory (i.e.
|
||||
/// `${XDG_CONFIG_HOME}/sshca/config.toml`).
|
||||
///
|
||||
/// If the configuration file does not exist, the default values will be
|
||||
/// used. If any error is encountered while reading or parsing the
|
||||
/// file, a [`ConfigError`] will be returned.
|
||||
pub fn load_config<P>(path: Option<P>) -> Result<Configuration, ConfigError>
|
||||
where
|
||||
P: AsRef<Path>,
|
||||
{
|
||||
let path = match path {
|
||||
Some(p) => PathBuf::from(p.as_ref()),
|
||||
None => default_config_path("config.toml"),
|
||||
};
|
||||
debug!("Loading configuration from {}", path.display());
|
||||
match std::fs::read_to_string(path) {
|
||||
Ok(s) => Ok(toml::from_str(&s)?),
|
||||
Err(ref e) if e.kind() == ErrorKind::NotFound => {
|
||||
Ok(Default::default())
|
||||
}
|
||||
Err(e) => Err(e.into()),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use serial_test::serial;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn test_default_config() {
|
||||
std::env::remove_var("XDG_CONFIG_HOME");
|
||||
std::env::set_var("HOME", "/home/user");
|
||||
let config = Configuration::default();
|
||||
assert_eq!(
|
||||
config.machine_ids,
|
||||
PathBuf::from("/home/user/.config/sshca/machine_ids.json"),
|
||||
);
|
||||
assert_eq!(config.libvirt.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn test_default_config_path() {
|
||||
std::env::remove_var("XDG_CONFIG_HOME");
|
||||
std::env::set_var("HOME", "/home/user");
|
||||
let path = default_config_path("config.toml");
|
||||
assert_eq!(
|
||||
path,
|
||||
PathBuf::from("/home/user/.config/sshca/config.toml"),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn test_default_config_path_env() {
|
||||
std::env::set_var("XDG_CONFIG_HOME", "/etc");
|
||||
let path = default_config_path("config.toml");
|
||||
assert_eq!(path, PathBuf::from("/etc/sshca/config.toml"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_toml() {
|
||||
let config_toml = r#"
|
||||
[[libvirt]]
|
||||
uri = "qemu+ssh://vmhost0.example.org/system"
|
||||
|
||||
[[libvirt]]
|
||||
uri = "qemu+ssh://vmhost1.example.org/system"
|
||||
"#;
|
||||
let config: Configuration = toml::from_str(config_toml).unwrap();
|
||||
assert_eq!(config.libvirt.len(), 2);
|
||||
assert_eq!(
|
||||
config.libvirt[0].uri,
|
||||
"qemu+ssh://vmhost0.example.org/system"
|
||||
);
|
||||
assert_eq!(
|
||||
config.libvirt[1].uri,
|
||||
"qemu+ssh://vmhost1.example.org/system"
|
||||
);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
mod auth;
|
||||
mod ca;
|
||||
pub mod config;
|
||||
#[cfg(feature = "libvirt")]
|
||||
mod libvirt;
|
||||
mod machine_id;
|
||||
pub mod server;
|
|
@ -0,0 +1,53 @@
|
|||
//! libvirt VM info
|
||||
use tracing::error;
|
||||
use virt::connect::Connect;
|
||||
use virt::domain::Domain;
|
||||
|
||||
use crate::config::LibvirtConfig;
|
||||
|
||||
/// libvirt connection wrapper with auto-disconnect
|
||||
///
|
||||
/// This structure wraps a virConnect object and automatically calls
|
||||
/// virConnectClose when it is dropped.
|
||||
struct Connection {
|
||||
conn: Connect,
|
||||
}
|
||||
|
||||
impl Drop for Connection {
|
||||
fn drop(&mut self) {
|
||||
if let Err(e) = self.conn.close() {
|
||||
error!("Failed to close libvirt connection: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the UUID of a libvirt VM
|
||||
///
|
||||
/// This function connects to the libvirt daemon specified in the provided
|
||||
/// [`LibvirtConfig`] structure and attempts to get the UUID of the virtual
|
||||
/// machine with the given name. If the connection cannot be established, or
|
||||
/// no machine is found with that name, an error is returned.
|
||||
pub fn get_machine_id(
|
||||
name: &str,
|
||||
config: &LibvirtConfig,
|
||||
) -> Result<String, virt::error::Error> {
|
||||
let conn = Connection {
|
||||
conn: Connect::open_read_only(&config.uri)?,
|
||||
};
|
||||
let dom = Domain::lookup_by_name(&conn.conn, name)?;
|
||||
dom.get_uuid_string()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_get_machine_id() {
|
||||
let config = LibvirtConfig {
|
||||
uri: "test:///default".into(),
|
||||
};
|
||||
let machine_id = get_machine_id("test", &config).unwrap();
|
||||
assert_eq!(&machine_id, "6695eb01-f6a4-8304-79aa-97f2502e193f");
|
||||
}
|
||||
}
|
|
@ -0,0 +1,108 @@
|
|||
//! Look up a known machine ID
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde_json::Value;
|
||||
use tokio::task;
|
||||
use tracing::{debug, error};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::config::Configuration;
|
||||
use crate::libvirt;
|
||||
|
||||
/// Look up the machine ID for a host
|
||||
///
|
||||
/// This function will return the machine ID for a known host, either from the
|
||||
/// static map stored in a JSON file or from a libvirt VM host. If no location
|
||||
/// has a record for the machine, `None` is returned.
|
||||
pub async fn get_machine_id(
|
||||
hostname: &str,
|
||||
config: Arc<Configuration>,
|
||||
) -> Option<Uuid> {
|
||||
if let Some(v) = from_map(hostname, config.clone()).await {
|
||||
return parse_uuid(&v);
|
||||
}
|
||||
#[cfg(feature = "libvirt")]
|
||||
if let Some(v) = from_libvirt(hostname, config.clone()).await {
|
||||
return parse_uuid(&v);
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Look up a machine ID in the JSON map file
|
||||
///
|
||||
/// This function reads the machine ID map stored in the JSON file. If the
|
||||
/// file could not be opened, is not a valid JSON object, or does not contain
|
||||
/// an entry for the specified hostname, `None` is returned.
|
||||
async fn from_map(
|
||||
hostname: &str,
|
||||
config: Arc<Configuration>,
|
||||
) -> Option<String> {
|
||||
let data = match tokio::fs::read_to_string(&config.machine_ids).await {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
error!("Failed to read machine ID map file: {}", e);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
let res =
|
||||
task::spawn_blocking(move || serde_json::from_str::<Value>(&data))
|
||||
.await;
|
||||
let map = match res {
|
||||
Ok(Ok(r)) => r,
|
||||
Ok(Err(e)) => {
|
||||
error!("Error parsing machine ID map: {}", e);
|
||||
return None;
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Unexpected error while parsing JSON: {}", e);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
Some(map.as_object()?.get(hostname)?.as_str()?.to_owned())
|
||||
}
|
||||
|
||||
/// Look up a machine ID on the configured libvirt VM hosts
|
||||
///
|
||||
/// This function iterates over the configured VM hosts and checks each one for
|
||||
/// a domain matching the specified hostname. If no VM host is available or
|
||||
/// none of the configured hosts have a matching domain, `None` is returned.
|
||||
#[cfg(feature = "libvirt")]
|
||||
async fn from_libvirt(
|
||||
hostname: &str,
|
||||
config: Arc<Configuration>,
|
||||
) -> Option<String> {
|
||||
let hostname = Arc::new(hostname.split('.').next().unwrap().to_string());
|
||||
let res = task::spawn_blocking(move || {
|
||||
for libvirt in &config.libvirt {
|
||||
debug!("Checking {} for {}", libvirt.uri, hostname);
|
||||
match libvirt::get_machine_id(&hostname, libvirt) {
|
||||
Ok(v) => return Some(v),
|
||||
Err(e) => {
|
||||
debug!("libvirt error: {}", e);
|
||||
}
|
||||
};
|
||||
}
|
||||
None
|
||||
})
|
||||
.await;
|
||||
match res {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
error!("Unexpected error while querying libvirt: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a UUID from a string
|
||||
///
|
||||
/// Returns `None` if the string does not contain a valid UUID.
|
||||
fn parse_uuid(value: &str) -> Option<Uuid> {
|
||||
Uuid::parse_str(value).map_or_else(
|
||||
|e| {
|
||||
error!("Invalid UUID: {}", e);
|
||||
None
|
||||
},
|
||||
Some,
|
||||
)
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
use std::path::PathBuf;
|
||||
|
||||
use argh::FromArgs;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
use sshca::config;
|
||||
use sshca::server;
|
||||
|
||||
/// Online SSH CA Service
|
||||
#[derive(FromArgs)]
|
||||
struct Args {
|
||||
/// path to configuration file
|
||||
#[argh(option, short = 'c')]
|
||||
config_file: Option<PathBuf>,
|
||||
|
||||
/// listen address
|
||||
#[argh(option, short = 'l')]
|
||||
listen_address: Option<String>,
|
||||
}
|
||||
|
||||
#[tokio::main(flavor = "current_thread")]
|
||||
async fn main() {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(EnvFilter::from_default_env())
|
||||
.with_writer(std::io::stderr)
|
||||
.init();
|
||||
|
||||
// stop libvirt printing to stderr
|
||||
virt::error::clear_error_callback();
|
||||
|
||||
let args: Args = argh::from_env();
|
||||
|
||||
let listen_address = args
|
||||
.listen_address
|
||||
.as_ref()
|
||||
.unwrap_or(&"[::]:8087".into())
|
||||
.parse()
|
||||
.unwrap();
|
||||
|
||||
let config = config::load_config(args.config_file.as_deref()).unwrap();
|
||||
let app = server::make_app(config);
|
||||
axum::Server::bind(&listen_address)
|
||||
.serve(app.into_make_service())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
|
@ -0,0 +1,147 @@
|
|||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use axum::extract::multipart::{Multipart, MultipartError};
|
||||
use axum::extract::State;
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use serde::Serialize;
|
||||
use ssh_key::Algorithm;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::auth::Claims;
|
||||
use crate::ca;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct SignKeyResponse {
|
||||
success: bool,
|
||||
errors: Vec<String>,
|
||||
certificates: HashMap<String, String>,
|
||||
}
|
||||
|
||||
pub enum SignKeyError {
|
||||
Multipart(MultipartError),
|
||||
Cert(ca::CertError),
|
||||
LoadPrivateKey(ca::LoadKeyError),
|
||||
ParsePublicKey(ca::LoadKeyError),
|
||||
UnsupportedAlgorithm(String),
|
||||
NoKey,
|
||||
}
|
||||
|
||||
impl From<MultipartError> for SignKeyError {
|
||||
fn from(e: MultipartError) -> Self {
|
||||
Self::Multipart(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ca::CertError> for SignKeyError {
|
||||
fn from(e: ca::CertError) -> Self {
|
||||
Self::Cert(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for SignKeyError {
|
||||
fn into_response(self) -> Response {
|
||||
match self {
|
||||
Self::Multipart(e) => {
|
||||
debug!("Error reading request: {}", e);
|
||||
let body = e.to_string();
|
||||
(StatusCode::BAD_REQUEST, body).into_response()
|
||||
}
|
||||
Self::Cert(e) => {
|
||||
error!("Failed to sign certificate: {}", e);
|
||||
let body = "Service Unavailable";
|
||||
(StatusCode::SERVICE_UNAVAILABLE, body).into_response()
|
||||
}
|
||||
Self::LoadPrivateKey(e) => {
|
||||
error!("Error loading CA private key: {}", e);
|
||||
let body = "Service Unavailable";
|
||||
(StatusCode::SERVICE_UNAVAILABLE, body).into_response()
|
||||
}
|
||||
Self::ParsePublicKey(e) => {
|
||||
error!("Error parsing public keykey: {}", e);
|
||||
let body = e.to_string();
|
||||
(StatusCode::BAD_REQUEST, body).into_response()
|
||||
}
|
||||
Self::UnsupportedAlgorithm(a) => {
|
||||
debug!("Requested certificate for unsupported key algorithm \"{}\"", a);
|
||||
let body = format!("Unsupported key algorithm: {}", a);
|
||||
(StatusCode::BAD_REQUEST, body).into_response()
|
||||
}
|
||||
Self::NoKey => {
|
||||
debug!("No SSH public key provided in request");
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"No SSH public key provided in request",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct SignKeyRequest {
|
||||
hostname: String,
|
||||
pubkey: Vec<u8>,
|
||||
aliases: Vec<String>,
|
||||
}
|
||||
|
||||
pub(super) async fn sign_host_cert(
|
||||
claims: Claims,
|
||||
State(ctx): State<super::State>,
|
||||
mut form: Multipart,
|
||||
) -> Result<String, SignKeyError> {
|
||||
let hostname = claims.sub;
|
||||
let mut body = SignKeyRequest::default();
|
||||
|
||||
while let Some(field) = form.next_field().await? {
|
||||
match field.name() {
|
||||
Some("pubkey") => {
|
||||
body.pubkey = field.bytes().await?.into();
|
||||
}
|
||||
Some("alias") => body.aliases.push(field.text().await?),
|
||||
Some("hostname") => body.hostname = field.text().await?,
|
||||
Some(n) => {
|
||||
warn!("Client request included unsupported field {:?}", n);
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
}
|
||||
if body.pubkey.is_empty() {
|
||||
return Err(SignKeyError::NoKey);
|
||||
}
|
||||
let aliases: Vec<_> = body.aliases.iter().map(String::as_ref).collect();
|
||||
|
||||
let config = &ctx.config;
|
||||
let duration = Duration::from_secs(config.ca.host.cert_duration);
|
||||
let privkey = ca::load_private_key(&config.ca.host.private_key_file)
|
||||
.await
|
||||
.map_err(SignKeyError::LoadPrivateKey)?;
|
||||
|
||||
let pubkey = ca::parse_public_key(&body.pubkey)
|
||||
.map_err(SignKeyError::ParsePublicKey)?;
|
||||
match pubkey.algorithm() {
|
||||
Algorithm::Ecdsa { .. } => (),
|
||||
Algorithm::Ed25519 => (),
|
||||
Algorithm::Rsa { .. } => (),
|
||||
_ => {
|
||||
return Err(SignKeyError::UnsupportedAlgorithm(
|
||||
pubkey.algorithm().as_str().into(),
|
||||
));
|
||||
}
|
||||
};
|
||||
debug!(
|
||||
"Signing {} key for {}",
|
||||
pubkey.algorithm().as_str(),
|
||||
hostname
|
||||
);
|
||||
let cert =
|
||||
ca::sign_cert(&hostname, &pubkey, duration, &privkey, &aliases)?;
|
||||
info!(
|
||||
"Signed {} key for {}",
|
||||
pubkey.algorithm().as_str(),
|
||||
hostname
|
||||
);
|
||||
Ok(cert.to_openssh().map_err(ca::CertError::from)?)
|
||||
}
|
|
@ -0,0 +1,110 @@
|
|||
mod host;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::async_trait;
|
||||
use axum::extract::FromRequestParts;
|
||||
use axum::headers::authorization::Bearer;
|
||||
use axum::headers::{Authorization, Host};
|
||||
use axum::http::request::Parts;
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::routing::{get, post};
|
||||
use axum::{RequestPartsExt, Router, TypedHeader};
|
||||
use tracing::debug;
|
||||
|
||||
use crate::auth::{self, Claims};
|
||||
use crate::config::Configuration;
|
||||
use crate::machine_id;
|
||||
|
||||
struct Context {
|
||||
config: Arc<Configuration>,
|
||||
}
|
||||
|
||||
type State = Arc<Context>;
|
||||
|
||||
pub struct AuthError;
|
||||
|
||||
impl IntoResponse for AuthError {
|
||||
fn into_response(self) -> Response {
|
||||
(StatusCode::UNAUTHORIZED, "Unauthorized").into_response()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl FromRequestParts<Arc<Context>> for Claims {
|
||||
type Rejection = AuthError;
|
||||
|
||||
async fn from_request_parts(
|
||||
parts: &mut Parts,
|
||||
ctx: &State,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
let TypedHeader(Authorization(bearer)) = parts
|
||||
.extract::<TypedHeader<Authorization<Bearer>>>()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
debug!("Failed to extract token from HTTP request: {}", e);
|
||||
AuthError
|
||||
})?;
|
||||
let host = parts.extract::<TypedHeader<Host>>().await.map_or_else(
|
||||
|_| "localhost".to_owned(),
|
||||
|v| v.hostname().to_owned(),
|
||||
);
|
||||
|
||||
let hostname =
|
||||
auth::get_token_subject(bearer.token()).map_err(|e| {
|
||||
debug!("Could not get token subject: {}", e);
|
||||
AuthError
|
||||
})?;
|
||||
let machine_id =
|
||||
machine_id::get_machine_id(&hostname, ctx.config.clone())
|
||||
.await
|
||||
.ok_or_else(|| {
|
||||
debug!("No machine ID found for host {}", hostname);
|
||||
AuthError
|
||||
})?;
|
||||
let claims = auth::validate_token(
|
||||
bearer.token(),
|
||||
&hostname,
|
||||
&machine_id,
|
||||
&host,
|
||||
)
|
||||
.map_err(|e| {
|
||||
debug!("Invalid auth token: {}", e);
|
||||
AuthError
|
||||
})?;
|
||||
debug!("Successfully authenticated request from host {}", hostname);
|
||||
Ok(claims)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn make_app(config: Configuration) -> Router {
|
||||
let ctx = Arc::new(Context {
|
||||
config: config.into(),
|
||||
});
|
||||
Router::new()
|
||||
.route("/", get(|| async { "UP" }))
|
||||
.route("/host/sign", post(host::sign_host_cert))
|
||||
.with_state(ctx)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use axum::body::Body;
|
||||
use axum::http::Request;
|
||||
use tower::ServiceExt;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_up() {
|
||||
let app = make_app(Configuration::default());
|
||||
let response = app
|
||||
.oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
|
||||
assert_eq!(&body[..], b"UP");
|
||||
}
|
||||
}
|
|
@ -0,0 +1,2 @@
|
|||
pub mod setup;
|
||||
pub mod token;
|
|
@ -0,0 +1,71 @@
|
|||
use std::error::Error;
|
||||
use std::io::prelude::*;
|
||||
use std::path::Path;
|
||||
use std::sync::Once;
|
||||
|
||||
use rand_core::OsRng;
|
||||
use ssh_key::{Algorithm, Fingerprint, PrivateKey, PublicKey};
|
||||
use tempfile::NamedTempFile;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
use sshca::config::Configuration;
|
||||
|
||||
static INIT: Once = Once::new();
|
||||
|
||||
fn gen_machine_ids() -> Result<NamedTempFile, Box<dyn Error>> {
|
||||
let f = NamedTempFile::new()?;
|
||||
let map = serde_json::json!({
|
||||
"test.example.org": "b75e9126-d73a-4ae0-9a0d-63cb3552e6cd",
|
||||
});
|
||||
serde_json::to_writer(&f, &map)?;
|
||||
Ok(f)
|
||||
}
|
||||
|
||||
fn gen_config(machine_ids: &Path, host_key: &Path) -> Configuration {
|
||||
let mut config = Configuration {
|
||||
machine_ids: machine_ids.to_str().unwrap().into(),
|
||||
..Default::default()
|
||||
};
|
||||
config.ca.host.private_key_file = host_key.to_str().unwrap().into();
|
||||
config
|
||||
}
|
||||
|
||||
fn gen_ca_key() -> Result<(NamedTempFile, PublicKey), Box<dyn Error>> {
|
||||
let key = PrivateKey::random(&mut OsRng, Algorithm::Ed25519)?;
|
||||
let mut f = NamedTempFile::new()?;
|
||||
f.write_all(key.to_openssh(Default::default())?.as_bytes())?;
|
||||
Ok((f, key.public_key().clone()))
|
||||
}
|
||||
|
||||
pub async fn setup() -> Result<(TestContext, Configuration), Box<dyn Error>> {
|
||||
INIT.call_once(|| {
|
||||
tracing_subscriber::fmt::fmt()
|
||||
.with_env_filter(EnvFilter::from("sshca=trace"))
|
||||
.with_test_writer()
|
||||
.init();
|
||||
});
|
||||
|
||||
let machine_ids = gen_machine_ids()?;
|
||||
let (host_key, host_key_pub) = gen_ca_key()?;
|
||||
let config = gen_config(machine_ids.path(), host_key.path());
|
||||
|
||||
let ctx = TestContext {
|
||||
machine_ids,
|
||||
host_key,
|
||||
host_key_pub,
|
||||
};
|
||||
Ok((ctx, config))
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub struct TestContext {
|
||||
machine_ids: NamedTempFile,
|
||||
host_key: NamedTempFile,
|
||||
host_key_pub: PublicKey,
|
||||
}
|
||||
|
||||
impl TestContext {
|
||||
pub fn host_ca_fingerprint(&self) -> Fingerprint {
|
||||
self.host_key_pub.fingerprint(Default::default())
|
||||
}
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
use std::time;
|
||||
|
||||
use argon2::Argon2;
|
||||
use jsonwebtoken::{encode, EncodingKey};
|
||||
use serde::Serialize;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct TestClaims {
|
||||
sub: String,
|
||||
iss: String,
|
||||
aud: String,
|
||||
iat: u64,
|
||||
nbf: u64,
|
||||
exp: u64,
|
||||
}
|
||||
|
||||
pub fn make_token(hostname: &str, machine_id: Uuid) -> String {
|
||||
let now = time::SystemTime::now()
|
||||
.duration_since(time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
let claims = TestClaims {
|
||||
sub: hostname.into(),
|
||||
iss: hostname.into(),
|
||||
aud: "sshca.example.org".into(),
|
||||
nbf: now - 60,
|
||||
iat: now,
|
||||
exp: now + 60,
|
||||
};
|
||||
let mut secret = [0u8; 32];
|
||||
Argon2::default()
|
||||
.hash_password_into(
|
||||
machine_id.as_bytes(),
|
||||
hostname.as_bytes(),
|
||||
&mut secret,
|
||||
)
|
||||
.unwrap();
|
||||
let key = EncodingKey::from_secret(&secret);
|
||||
encode(&Default::default(), &claims, &key).unwrap()
|
||||
}
|
|
@ -0,0 +1,209 @@
|
|||
mod common;
|
||||
|
||||
use axum::body::Body;
|
||||
use axum::http::{Request, StatusCode};
|
||||
use form_data_builder::FormData;
|
||||
use ssh_key::{Algorithm, Certificate};
|
||||
use tower::ServiceExt;
|
||||
use uuid::uuid;
|
||||
|
||||
use sshca::server::make_app;
|
||||
|
||||
use common::setup;
|
||||
use common::token;
|
||||
|
||||
const ED25519_KEY: &str = concat!(
|
||||
"ssh-ed25519 ",
|
||||
"AAAAC3NzaC1lZDI1NTE5AAAAIAsFEmrNIoRHHUayEO0NdAIgtMvci/wME07h+A5XSNJy",
|
||||
);
|
||||
const DSA_KEY: &str = concat!(
|
||||
"ssh-dss ",
|
||||
"AAAAB3NzaC1kc3MAAACBALNAS+fZjaWt4q+MAgjf6HREFoYjgoSVJUUCtNmRGhND85msVtla",
|
||||
"kll2gLzL6n6TWyiToARlThoTFu1ZDoGYauDL7iDXrGB6VWJEOQZ3TEMHLFYPziW02AbjR9GI",
|
||||
"ptsF42D0bTTvvaIaBIhOTjWAUjuFIhAKhPkcj+udIcyH8CG1AAAAFQCpbXQSlxOvd4J92j2C",
|
||||
"rWDYVGoK8wAAAIAfHiV6/glGZrDRztJmw1hfwbmiNPxaoSGkB+Necfkj0fZrlyLj8sLJIbGQ",
|
||||
"w0dJMATZdRHw3Ql4R5IOu7sBfX1KQW++onT4ads/Xtl6vwfsjO2e/a6Y1ib9JCIOGJxNAAUC",
|
||||
"JU0Fm0TSv2Nn6UTICAarp1eKALimqkvy1+ygBWjprgAAAIEAic5EpZH9wpgzvl9kPW531yrz",
|
||||
"IOlCcXsJFPqQxUThrB2o1g3Rjpscd9kCw5UlPu6GGLk4aSN3UxeIKymTuKiEi7tvP1Tj/Bv5",
|
||||
"tEc4rhfmrBAfAST09oRFDsELufsOAlTrJ0uk2LhtN14H1RBv9qPR5PQKTEYslyvXG1f8itNQ",
|
||||
"YnQ="
|
||||
);
|
||||
|
||||
fn make_test_request_body(key: &[u8], name: &str) -> (Body, String) {
|
||||
let mut form = FormData::new(Vec::new());
|
||||
form.write_file(
|
||||
"pubkey",
|
||||
key,
|
||||
Some(name.as_ref()),
|
||||
"application/octet-stream",
|
||||
)
|
||||
.unwrap();
|
||||
let content_type = form.content_type_header();
|
||||
let body = Body::from(form.finish().unwrap());
|
||||
(body, content_type)
|
||||
}
|
||||
|
||||
fn make_test_request(body: Body, content_type: &str) -> Request<Body> {
|
||||
let hostname = "test.example.org";
|
||||
let machine_id = uuid!("b75e9126-d73a-4ae0-9a0d-63cb3552e6cd");
|
||||
let token = token::make_token(hostname, machine_id);
|
||||
Request::builder()
|
||||
.uri("/host/sign")
|
||||
.method("POST")
|
||||
.header("Authorization", format!("Bearer {}", token))
|
||||
.header("Host", "sshca.example.org")
|
||||
.header("Content-Type", content_type)
|
||||
.body(body)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sign() {
|
||||
let (ctx, config) = setup::setup().await.unwrap();
|
||||
|
||||
let app = make_app(config);
|
||||
let (body, content_type) = make_test_request_body(
|
||||
ED25519_KEY.as_bytes(),
|
||||
"ssh_host_ed25519_key.pub",
|
||||
);
|
||||
let req = make_test_request(body, &content_type);
|
||||
let res = app.oneshot(req).await.unwrap();
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
|
||||
let cert = Certificate::from_openssh(std::str::from_utf8(&body).unwrap())
|
||||
.unwrap();
|
||||
assert_eq!(cert.algorithm(), Algorithm::Ed25519);
|
||||
cert.validate(&[ctx.host_ca_fingerprint()]).unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sign_invalid() {
|
||||
let (_ctx, config) = setup::setup().await.unwrap();
|
||||
|
||||
let (body, content_type) = make_test_request_body(
|
||||
"this is not a valid openssh key".as_bytes(),
|
||||
"ssh_host_ecdsa_key.pub",
|
||||
);
|
||||
let app = make_app(config);
|
||||
let req = make_test_request(body, &content_type);
|
||||
let res = app.oneshot(req).await.unwrap();
|
||||
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
|
||||
assert_eq!(
|
||||
body,
|
||||
concat!(
|
||||
"Could not parse SSH key: ",
|
||||
"Base64 encoding error: invalid Base64 encoding",
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sign_nokey() {
|
||||
let (_ctx, config) = setup::setup().await.unwrap();
|
||||
|
||||
let mut form = FormData::new(Vec::new());
|
||||
let content_type = form.content_type_header();
|
||||
let body = Body::from(form.finish().unwrap());
|
||||
let app = make_app(config);
|
||||
let req = make_test_request(body, &content_type);
|
||||
let res = app.oneshot(req).await.unwrap();
|
||||
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
|
||||
assert_eq!(body, "No SSH public key provided in request",);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sign_mangled() {
|
||||
let (_ctx, config) = setup::setup().await.unwrap();
|
||||
|
||||
let app = make_app(config);
|
||||
let mut form = FormData::new(Vec::new());
|
||||
form.write_file(
|
||||
"pubkey",
|
||||
ED25519_KEY.as_bytes(),
|
||||
Some("ssh_host_ed25519_key.pub".as_ref()),
|
||||
"application/octet-stream",
|
||||
)
|
||||
.unwrap();
|
||||
let content_type = form.content_type_header();
|
||||
let mut form_bytes = form.finish().unwrap();
|
||||
form_bytes.truncate(19);
|
||||
let body = Body::from(form_bytes);
|
||||
let req = make_test_request(body, &content_type);
|
||||
let res = app.oneshot(req).await.unwrap();
|
||||
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
|
||||
assert_eq!(body, "Error parsing `multipart/form-data` request",);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sign_bad_request() {
|
||||
let (_ctx, config) = setup::setup().await.unwrap();
|
||||
|
||||
let app = make_app(config);
|
||||
let content_type = "text/plain";
|
||||
let body = Body::from("test");
|
||||
let req = make_test_request(body, content_type);
|
||||
let res = app.oneshot(req).await.unwrap();
|
||||
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
|
||||
assert_eq!(body, "Invalid `boundary` for `multipart/form-data` request",);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sign_dsa() {
|
||||
let (_ctx, config) = setup::setup().await.unwrap();
|
||||
|
||||
let app = make_app(config);
|
||||
let (body, content_type) =
|
||||
make_test_request_body(DSA_KEY.as_bytes(), "ssh_host_dsa_key.pub");
|
||||
let req = make_test_request(body, &content_type);
|
||||
let res = app.oneshot(req).await.unwrap();
|
||||
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
|
||||
assert_eq!(body, "Unsupported key algorithm: ssh-dss");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sign_failure() {
|
||||
let (_ctx, mut config) = setup::setup().await.unwrap();
|
||||
config.ca.host.private_key_file = "bogus".into();
|
||||
|
||||
let app = make_app(config);
|
||||
let (body, content_type) = make_test_request_body(
|
||||
ED25519_KEY.as_bytes(),
|
||||
"ssh_host_ed25519_key.pub",
|
||||
);
|
||||
let req = make_test_request(body, &content_type);
|
||||
let res = app.oneshot(req).await.unwrap();
|
||||
|
||||
assert_eq!(res.status(), StatusCode::SERVICE_UNAVAILABLE);
|
||||
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
|
||||
assert_eq!(body, "Service Unavailable");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sign_unauthorized() {
|
||||
// Deliberately drop the TestContext so the machine ID file gets deleted,
|
||||
// which will cause authentication to fail.
|
||||
let (_, config) = setup::setup().await.unwrap();
|
||||
|
||||
let app = make_app(config);
|
||||
let (body, content_type) = make_test_request_body(
|
||||
ED25519_KEY.as_bytes(),
|
||||
"ssh_host_ed25519_key.pub",
|
||||
);
|
||||
let req = make_test_request(body, &content_type);
|
||||
let res = app.oneshot(req).await.unwrap();
|
||||
|
||||
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
|
||||
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
|
||||
assert_eq!(body, "Unauthorized");
|
||||
}
|
Loading…
Reference in New Issue