diff --git a/src/config.rs b/src/config.rs index c69b868..6e4b089 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,4 +1,5 @@ //! Application configuration +use std::collections::HashMap; use std::io::ErrorKind; use std::path::{Path, PathBuf}; @@ -79,6 +80,10 @@ pub struct UserCaConfig { /// Certificate extensions #[serde(default = "default_user_cert_extensions")] pub extensions: Vec, + + /// Additional principals to add based on user's group membership + #[serde(default)] + pub group_principals: HashMap> } impl Default for UserCaConfig { @@ -88,6 +93,7 @@ impl Default for UserCaConfig { private_key_passphrase_file: None, cert_duration: default_user_cert_duration(), extensions: default_user_cert_extensions(), + group_principals: Default::default(), } } } diff --git a/src/server/mod.rs b/src/server/mod.rs index db8c7e0..edd7bf2 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,5 +1,6 @@ mod error; mod host; +mod oidc; mod user; use std::collections::HashMap; diff --git a/src/server/oidc.rs b/src/server/oidc.rs new file mode 100644 index 0000000..36f6ed1 --- /dev/null +++ b/src/server/oidc.rs @@ -0,0 +1,55 @@ +use openidconnect::core::*; +use openidconnect::*; +use serde::{Deserialize, Serialize}; + +pub type IdTokenFields = openidconnect::IdTokenFields< + AdditionalClaims, + EmptyExtraTokenFields, + CoreGenderClaim, + CoreJweContentEncryptionAlgorithm, + CoreJwsSigningAlgorithm, + CoreJsonWebKeyType, +>; + +pub type IdToken = openidconnect::IdToken< + AdditionalClaims, + CoreGenderClaim, + CoreJweContentEncryptionAlgorithm, + CoreJwsSigningAlgorithm, + CoreJsonWebKeyType, +>; + +pub type IdTokenClaims = + openidconnect::IdTokenClaims; + +pub type TokenResponse = StandardTokenResponse; + +pub type Client = openidconnect::Client< + AdditionalClaims, + CoreAuthDisplay, + CoreGenderClaim, + CoreJweContentEncryptionAlgorithm, + CoreJwsSigningAlgorithm, + CoreJsonWebKeyType, + CoreJsonWebKeyUse, + CoreJsonWebKey, + CoreAuthPrompt, + StandardErrorResponse, + TokenResponse, + CoreTokenType, + CoreTokenIntrospectionResponse, + CoreRevocableToken, + CoreRevocationErrorResponse, +>; + +#[derive(Serialize, Deserialize, Debug)] +pub struct AdditionalClaims { + groups: Vec, +} + +impl AdditionalClaims { + pub fn groups(&self) -> &Vec { + &self.groups + } +} +impl openidconnect::AdditionalClaims for AdditionalClaims {} diff --git a/src/server/user.rs b/src/server/user.rs index 92df14f..058bde1 100644 --- a/src/server/user.rs +++ b/src/server/user.rs @@ -12,8 +12,7 @@ use axum::headers::Authorization; use axum::http::request::Parts; use axum::Json; use axum::{RequestPartsExt, TypedHeader}; -use openidconnect::core::{CoreClient, CoreProviderMetadata}; -use openidconnect::core::{CoreIdToken, CoreIdTokenClaims}; +use openidconnect::core::CoreProviderMetadata; use openidconnect::reqwest::async_http_client; use openidconnect::IssuerUrl; use openidconnect::Nonce; @@ -24,6 +23,7 @@ use tracing::{debug, error, info, trace, warn}; use super::error::SignKeyError; use super::{AuthError, Context}; +use super::oidc; use crate::ca; /// Response type for GET /user/openid-config @@ -38,7 +38,7 @@ pub struct OidcConfigResponse { } /// OpenID Connect ID token claims -pub struct Claims(CoreIdTokenClaims); +pub struct Claims(oidc::IdTokenClaims); /// Axum request extractor for OIDC ID tokens in Authorization headers /// @@ -69,7 +69,7 @@ impl FromRequestParts> for Claims { AuthError })?; - let token = CoreIdToken::from_str(bearer.token()).map_err(|e| { + let token = oidc::IdToken::from_str(bearer.token()).map_err(|e| { debug!("Failed to parse OIDC ID token: {}", e); AuthError })?; @@ -77,7 +77,7 @@ impl FromRequestParts> for Claims { let client_id = &oidc_config.client_id; let client_secret = &oidc_config.client_secret; let provider_metadata = get_metadata(ctx).await.ok_or(AuthError)?; - let client = CoreClient::from_provider_metadata( + let client = oidc::Client::from_provider_metadata( provider_metadata, ClientId::new(client_id.into()), client_secret.as_ref().map(|s| ClientSecret::new(s.into())), @@ -169,6 +169,16 @@ pub(super) async fn sign_user_cert( let config = &ctx.config; let duration = Duration::from_secs(config.ca.user.cert_duration); let extensions = &config.ca.user.extensions; + + for group in claims.additional_claims().groups() { + if let Some(principals) = config.ca.user.group_principals.get(group) { + debug!("Adding principals from group {}", group); + for p in principals { + alias.push(p.as_str()) + } + } + } + let privkey = ca::load_private_key( &config.ca.user.private_key_file, config.ca.user.private_key_passphrase_file.as_ref(),