diff --git a/.gitignore b/.gitignore index 704da5f..c38a49e 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ /Rocket.toml /jwt.secret /meilisearch.token +/oidc.secret diff --git a/src/auth.rs b/src/auth.rs index 6f28e13..0ab5526 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,3 +1,4 @@ +use std::path::Path; use std::collections::HashMap; use std::io::Read; use std::time::SystemTime; @@ -177,18 +178,15 @@ impl From for LoginFailed { /// Create an openidconnect Client instance from the given Config pub async fn get_oidc_client( - config: &Config, + discovery_url: String, + client_id: String, + client_secret: Option, + callback_url: String, + http_client: &reqwest::Client, ) -> Result { - let http_client = reqwest::ClientBuilder::new() - .redirect(reqwest::redirect::Policy::none()) - .build() - .unwrap(); - let oidc_url = config.oidc.discovery_url.clone(); - let client_id = config.oidc.client_id.clone(); - let client_secret = config.oidc.client_secret.clone(); let provider_metadata = CoreProviderMetadata::discover_async( - IssuerUrl::new(oidc_url)?, - &http_client, + IssuerUrl::new(discovery_url)?, + http_client, ) .await?; Ok(CoreClient::from_provider_metadata( @@ -196,13 +194,13 @@ pub async fn get_oidc_client( ClientId::new(client_id), client_secret.map(ClientSecret::new), ) - .set_redirect_uri(RedirectUrl::new(config.oidc.callback_url.clone())?)) + .set_redirect_uri(RedirectUrl::new(callback_url)?)) } -/// Load the JWT secret from the path specified in the configuration -pub fn load_secret(config: &Config) -> Result, std::io::Error> { +/// Load a secret from the specified path +pub fn load_secret>(path: P) -> Result, std::io::Error> { let mut secret = vec![]; - let mut f = std::fs::File::open(&config.auth.jwt_secret)?; + let mut f = std::fs::File::open(path)?; f.read_to_end(&mut secret)?; Ok(secret) } @@ -263,15 +261,12 @@ async fn exchange_code( ) -> Result { let pkce_verifier = PkceCodeVerifier::new(state.pkce); debug!("Exchanging authorization code for access token"); - let http_client = reqwest::ClientBuilder::new() - .redirect(reqwest::redirect::Policy::none()) - .build() - .unwrap(); - let token_response = ctx - .oidc + let http_client = &ctx.oidc_http_client; + let oidc = ctx.oidc().await?; + let token_response = oidc .exchange_code(AuthorizationCode::new(code))? .set_pkce_verifier(pkce_verifier) - .request_async(&http_client) + .request_async(http_client) .await .map_err(LoginError::TokenRequestError)?; debug!( @@ -284,9 +279,9 @@ async fn exchange_code( let id_token = token_response .id_token() .ok_or(LoginError::MissingIdToken)?; - let id_token_verifier = ctx.oidc.id_token_verifier(); + let id_token_verifier = oidc.id_token_verifier(); let claims = id_token - .claims(&ctx.oidc.id_token_verifier(), &Nonce::new(state.nonce))?; + .claims(&oidc.id_token_verifier(), &Nonce::new(state.nonce))?; if let Some(expected_access_token_hash) = claims.access_token_hash() { let actual_access_token_hash = AccessTokenHash::from_token( token_response.access_token(), @@ -306,13 +301,13 @@ pub async fn oidc_login( ctx: &State, cookies: &CookieJar<'_>, config: &RocketConfig, -) -> Redirect { +) -> Result { let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); trace!("PKCE: {:?} {:?}", pkce_challenge, pkce_verifier); - let (auth_url, csrf_token, nonce) = ctx - .oidc + let oidc = ctx.oidc().await.map_err(LoginError::from)?; + let (auth_url, csrf_token, nonce) = oidc .authorize_url( CoreAuthenticationFlow::AuthorizationCode, CsrfToken::new_random, @@ -355,7 +350,7 @@ pub async fn oidc_login( error!("Failed to serialize OIDC client state: {}", e); }, }; - Redirect::to(auth_url.to_string()) + Ok(Redirect::to(auth_url.to_string())) } /// Handle OpenID Connect authorization callback diff --git a/src/config.rs b/src/config.rs index 8470bef..4ea47ea 100644 --- a/src/config.rs +++ b/src/config.rs @@ -18,7 +18,7 @@ pub struct AuthConfig { pub login_ttl: u64, } -#[derive(Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize)] pub struct OidcConfig { pub discovery_url: String, pub client_id: String, diff --git a/src/error.rs b/src/error.rs index ff56d03..4c516f5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -51,16 +51,18 @@ pub enum LoginError { InvalidAccessToken, #[error("JWT serialization error: {0}")] JwtError(#[from] jsonwebtoken::errors::Error), + #[error("Invalid OIDC configuration: {0}")] + Oidc(#[from] OidcError), } #[derive(Debug, thiserror::Error)] pub enum InitError { #[error("Meilisearch error: {0}")] Meilisearch(#[from] crate::meilisearch::Error), - #[error("Failed to load JWT secret: {0}")] - LoadSecret(std::io::Error), - - #[error("Invalid OIDC configuration: {0}")] - Oidc(#[from] OidcError), + LoadJwtSecret(std::io::Error), + #[error("Failed to load OIDC client secret: {0}")] + LoadOidcSecret(std::io::Error), + #[error("Failed to initialize HTTP client: {0}")] + ReqwestError(#[from] reqwest::Error) } diff --git a/src/lib.rs b/src/lib.rs index 42dd76a..f5b006b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,21 +15,46 @@ pub use error::InitError; pub struct Context { client: MeilisearchClient, jwt_secret: Vec, - oidc: auth::OidcClient, + oidc: config::OidcConfig, + oidc_client_secret: Option, + oidc_http_client: reqwest::Client, } impl Context { pub async fn init(config: &Config) -> Result { let client = MeilisearchClient::try_from(config)?; - let jwt_secret = - auth::load_secret(config).map_err(InitError::LoadSecret)?; - let oidc = auth::get_oidc_client(config).await?; + let jwt_secret = auth::load_secret(&config.auth.jwt_secret) + .map_err(InitError::LoadJwtSecret)?; + let oidc = config.oidc.clone(); + let oidc_client_secret = match &config.oidc.client_secret { + Some(p) => match auth::load_secret(p) { + Ok(s) => Some(String::from_utf8_lossy(&s).trim().to_string()), + Err(e) => return Err(InitError::LoadOidcSecret(e)), + }, + None => None, + }; + let oidc_http_client = reqwest::ClientBuilder::new() + .redirect(reqwest::redirect::Policy::none()) + .build()?; Ok(Self { client, jwt_secret, oidc, + oidc_client_secret, + oidc_http_client, }) } + + pub async fn oidc(&self) -> Result { + auth::get_oidc_client( + self.oidc.discovery_url.clone(), + self.oidc.client_id.clone(), + self.oidc_client_secret.clone(), + self.oidc.callback_url.clone(), + &self.oidc_http_client, + ) + .await + } } /// Initialize the application context