auth: Implement OpenID Connect login flow
This commit adds two path operations, *GET /login* and *GET /oidc-callback*, which initiate and complete the OpenID connect login flow, respectively. Only the *Authorization Code* flow is supported, since this is the only flow implemented by Authelia. There is quite a bit of boilerplate required to fully implement an OIDC relying party, especially in Rust. The documentation for `openidconnect` is decent, but it still took quite a bit of trial and error to get everything working. After successfully finishing the OIDC login, the client will receive a cookie containing a JWT that can be used for further communication with the server. We're not using the OIDC tokens themselves for authorization. For development and testing, Dex is a simple and convenient OIDC IdP. The only caveat is its configuration file must contain list the TCP port clients will use to connect to it, meaning we cannot use Podman dynamic port allocation like we do for Meilisearch. Ultimately, this just means the integration tests will fail if there is another process already listening on 5556.
This commit is contained in:
1177
Cargo.lock
generated
1177
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -9,8 +9,10 @@ html5ever = "0.27.0"
|
||||
jsonwebtoken = { version = "9.3.1", default-features = false }
|
||||
markup5ever_rcdom = "0.3.0"
|
||||
meilisearch-sdk = "0.28.0"
|
||||
openidconnect = { version = "4.0.0", default-features = false, features = ["reqwest", "native-tls"] }
|
||||
rand = "0.9.0"
|
||||
rocket = { version = "0.5.1", features = ["json"] }
|
||||
reqwest = { version = "0.12.15", features = ["json", "native-tls"] }
|
||||
rocket = { version = "0.5.1", features = ["json", "secrets"] }
|
||||
serde = { version = "1.0.219", features = ["derive"] }
|
||||
thiserror = "2.0.12"
|
||||
tracing = "0.1.41"
|
||||
@@ -18,3 +20,5 @@ tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
|
||||
|
||||
[dev-dependencies]
|
||||
form_urlencoded = "1.2.1"
|
||||
reqwest = { version = "0.12.15", features = ["cookies"] }
|
||||
scraper = { version = "0.23.1", default-features = false }
|
||||
|
||||
29
dex.yaml
Normal file
29
dex.yaml
Normal file
@@ -0,0 +1,29 @@
|
||||
issuer: http://127.0.0.1:5556/dex
|
||||
|
||||
storage:
|
||||
type: memory
|
||||
|
||||
web:
|
||||
http: 0.0.0.0:5556
|
||||
|
||||
staticClients:
|
||||
- id: example-app
|
||||
redirectURIs:
|
||||
- 'http://localhost:8000/oidc-callback'
|
||||
name: 'Example App'
|
||||
secret: example-app-secret
|
||||
|
||||
enablePasswordDB: true
|
||||
|
||||
staticPasswords:
|
||||
- email: "user@example.com"
|
||||
# bcrypt hash of "password"
|
||||
hash: "$2y$10$BvrNnoSifaDAcS95zT5CyegMKE90S9gpiMHcJj82hVnyRYwn2LNwS"
|
||||
username: "testuser"
|
||||
userID: "1234"
|
||||
|
||||
oauth2:
|
||||
skipApprovalScreen: true
|
||||
|
||||
logger:
|
||||
level: debug
|
||||
27
run-tests.sh
27
run-tests.sh
@@ -11,6 +11,11 @@ cleanup() {
|
||||
rm "${meilisearch_key}"
|
||||
fi
|
||||
unset meilisearch_key
|
||||
if [ -n "${dex_cid}" ]; then
|
||||
echo 'Stopping Dex ...' >&2
|
||||
podman stop "${dex_cid}" >/dev/null
|
||||
fi
|
||||
unset cid
|
||||
if [ -f "${jwt_key}" ]; then
|
||||
rm "${jwt_key}"
|
||||
fi
|
||||
@@ -48,6 +53,22 @@ curl -s 127.0.0.1:${port}/keys -H "Authorization: Bearer ${MEILI_MASTER_KEY}" \
|
||||
| jq -r '.results[] | select(.name == "Default Admin API Key") | .key' \
|
||||
> "${meilisearch_key}"
|
||||
|
||||
echo 'Starting Dex ...'
|
||||
dex_cid=$(
|
||||
podman run --rm -d \
|
||||
-p 5556:5556 \
|
||||
-v ./dex.yaml:/etc/dex/config.yaml:ro,z \
|
||||
ghcr.io/dexidp/dex \
|
||||
dex serve /etc/dex/config.yaml
|
||||
) || exit
|
||||
dex_port=5556
|
||||
printf 'Waiting for Dex on port %d ...' "${dex_port}" >&2
|
||||
until curl -fs -o /dev/null 127.0.0.1:${dex_port}/dex/healthz; do
|
||||
printf '.'
|
||||
sleep 0.25
|
||||
done
|
||||
echo ' ready' >&2
|
||||
|
||||
echo 'Generating JWT secret key ...' >&2
|
||||
jwt_key=$(mktemp jwt.secret.XXXXXX)
|
||||
tr -cd 0-9A-Za-z_- < /dev/urandom | head -c 32 > "${jwt_key}"
|
||||
@@ -60,6 +81,12 @@ token = "${meilisearch_key}"
|
||||
|
||||
[test.auth]
|
||||
jwt_secret = "${jwt_key}"
|
||||
|
||||
[test.oidc]
|
||||
discovery_url = "http://127.0.0.1:${dex_port}/dex"
|
||||
client_id = "example-app"
|
||||
client_secret = "example-app-secret"
|
||||
callback_url = "http://localhost:8000/oidc-callback"
|
||||
EOF
|
||||
export ROCKET_PROFILE=test
|
||||
export ROCKET_CONFIG
|
||||
|
||||
303
src/auth.rs
303
src/auth.rs
@@ -1,16 +1,31 @@
|
||||
use std::collections::HashMap;
|
||||
use std::io::Read;
|
||||
use std::time::SystemTime;
|
||||
|
||||
use jsonwebtoken::{
|
||||
decode, encode, Algorithm, DecodingKey, EncodingKey, Header, TokenData,
|
||||
Validation,
|
||||
decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation,
|
||||
};
|
||||
use rocket::http::Status;
|
||||
use openidconnect::core::{
|
||||
CoreAuthenticationFlow, CoreClient, CoreIdTokenClaims,
|
||||
CoreProviderMetadata,
|
||||
};
|
||||
use openidconnect::{
|
||||
AccessTokenHash, AuthorizationCode, ClientId, ClientSecret, CsrfToken,
|
||||
EndpointMaybeSet, EndpointNotSet, EndpointSet, IssuerUrl, Nonce,
|
||||
OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl,
|
||||
Scope, TokenResponse,
|
||||
};
|
||||
use rocket::config::Config as RocketConfig;
|
||||
use rocket::http::{Cookie, CookieJar, SameSite, Status};
|
||||
use rocket::request::{FromRequest, Outcome, Request};
|
||||
use rocket::response::Redirect;
|
||||
use rocket::time::{Duration, OffsetDateTime};
|
||||
use rocket::State;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::error;
|
||||
use tracing::{debug, error, trace, warn};
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::error::{LoginError, OidcError};
|
||||
use crate::Context;
|
||||
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
||||
@@ -31,7 +46,7 @@ impl<'r> FromRequest<'r> for UserClaims {
|
||||
let secret = &req.rocket().state::<Context>().unwrap().jwt_secret;
|
||||
match req.headers().get_one("Authorization") {
|
||||
Some(s) => match extract_token(secret, s) {
|
||||
Ok(d) => Outcome::Success(d.claims),
|
||||
Ok(d) => Outcome::Success(d),
|
||||
Err(e) => Outcome::Error((Status::Unauthorized, e)),
|
||||
},
|
||||
None => Outcome::Error((Status::Unauthorized, "Unauthorized")),
|
||||
@@ -62,15 +77,89 @@ impl UserClaims {
|
||||
let k = EncodingKey::from_secret(secret);
|
||||
encode(&Header::default(), self, &k)
|
||||
}
|
||||
|
||||
fn from_jwt(
|
||||
token: &str,
|
||||
secret: &[u8],
|
||||
) -> Result<Self, jsonwebtoken::errors::Error> {
|
||||
let mut v = Validation::new(Algorithm::HS256);
|
||||
v.validate_nbf = true;
|
||||
v.set_issuer(&[env!("CARGO_PKG_NAME")]);
|
||||
v.set_audience(&[env!("CARGO_PKG_NAME")]);
|
||||
let k = DecodingKey::from_secret(secret);
|
||||
Ok(decode(token, &k, &v)?.claims)
|
||||
}
|
||||
}
|
||||
|
||||
impl Context {
|
||||
/// Create a signed JWT for a [`UserClaims`]
|
||||
pub fn make_jwt(
|
||||
&self,
|
||||
claims: &UserClaims,
|
||||
) -> Result<String, jsonwebtoken::errors::Error> {
|
||||
claims.to_jwt(&self.jwt_secret)
|
||||
}
|
||||
|
||||
/// Decode a [`UserClaims`] from a JWT
|
||||
pub fn decode_jwt(
|
||||
&self,
|
||||
token: &str,
|
||||
) -> Result<UserClaims, jsonwebtoken::errors::Error> {
|
||||
UserClaims::from_jwt(token, &self.jwt_secret)
|
||||
}
|
||||
}
|
||||
|
||||
pub type OidcClient = CoreClient<
|
||||
EndpointSet,
|
||||
EndpointNotSet,
|
||||
EndpointNotSet,
|
||||
EndpointNotSet,
|
||||
EndpointMaybeSet,
|
||||
EndpointMaybeSet,
|
||||
>;
|
||||
|
||||
/// Represents the OIDC login state, which is stored in a cookie
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct OidcClientState {
|
||||
pkce: String,
|
||||
csrf: String,
|
||||
nonce: String,
|
||||
}
|
||||
|
||||
/// Error response returned to the client when OIDC login fails
|
||||
#[derive(Debug, rocket::Responder)]
|
||||
#[response(status = 401)]
|
||||
pub struct LoginFailed(String);
|
||||
|
||||
impl From<LoginError> for LoginFailed {
|
||||
fn from(e: LoginError) -> Self {
|
||||
error!("Login failed: {}", e);
|
||||
Self(e.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an openidconnect Client instance from the given Config
|
||||
pub async fn get_oidc_client(
|
||||
config: &Config,
|
||||
) -> Result<OidcClient, OidcError> {
|
||||
let http_client = reqwest::ClientBuilder::new()
|
||||
.redirect(reqwest::redirect::Policy::none())
|
||||
.build()
|
||||
.unwrap();
|
||||
let oidc_url = config.oidc.discovery_url.clone();
|
||||
let client_id = config.oidc.client_id.clone();
|
||||
let client_secret = config.oidc.client_secret.clone();
|
||||
let provider_metadata = CoreProviderMetadata::discover_async(
|
||||
IssuerUrl::new(oidc_url)?,
|
||||
&http_client,
|
||||
)
|
||||
.await?;
|
||||
Ok(CoreClient::from_provider_metadata(
|
||||
provider_metadata,
|
||||
ClientId::new(client_id),
|
||||
client_secret.map(ClientSecret::new),
|
||||
)
|
||||
.set_redirect_uri(RedirectUrl::new(config.oidc.callback_url.clone())?))
|
||||
}
|
||||
|
||||
/// Load the JWT secret from the path specified in the configuration
|
||||
@@ -81,10 +170,11 @@ pub fn load_secret(config: &Config) -> Result<Vec<u8>, std::io::Error> {
|
||||
Ok(secret)
|
||||
}
|
||||
|
||||
/// Extract a JWT from the Authorization HTTP request header
|
||||
fn extract_token(
|
||||
secret: &[u8],
|
||||
header: &str,
|
||||
) -> Result<TokenData<UserClaims>, &'static str> {
|
||||
) -> Result<UserClaims, &'static str> {
|
||||
let mut iter = header.split_ascii_whitespace();
|
||||
let scheme = iter.next();
|
||||
let token = iter.next();
|
||||
@@ -92,23 +182,188 @@ fn extract_token(
|
||||
return Err("Unsupported authorization scheme");
|
||||
}
|
||||
if let Some(token) = token {
|
||||
let mut v = Validation::new(Algorithm::HS256);
|
||||
v.validate_nbf = true;
|
||||
v.set_issuer(&[env!("CARGO_PKG_NAME")]);
|
||||
v.set_audience(&[env!("CARGO_PKG_NAME")]);
|
||||
let k = DecodingKey::from_secret(secret);
|
||||
match decode(token, &k, &v) {
|
||||
Ok(d) => Ok(d),
|
||||
Err(e) => {
|
||||
error!("Failed to decode auth token: {}", e);
|
||||
Err("Invalid token")
|
||||
},
|
||||
}
|
||||
Ok(UserClaims::from_jwt(token, secret).map_err(|e| {
|
||||
error!("Failed to decode auth token: {}", e);
|
||||
"Invalid token"
|
||||
})?)
|
||||
} else {
|
||||
Err("Invalid Authorization header")
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the authorization code from OIDC callback parameters
|
||||
///
|
||||
/// If the callback indicates an authorization failure, or the OAuth2
|
||||
/// state is invalid, a [`LoginError`] is returned.
|
||||
fn get_auth_code(
|
||||
csrf_token: &str,
|
||||
params: &HashMap<String, String>,
|
||||
) -> Result<String, LoginError> {
|
||||
let state = params.get("state").ok_or(LoginError::MissingStateParam)?;
|
||||
if state != csrf_token {
|
||||
return Err(LoginError::InvalidCsrfState);
|
||||
}
|
||||
if let Some(error) = params.get("error") {
|
||||
let msg = if let Some(err_desc) = params.get("error_description") {
|
||||
format!("Error handling OIDC callback ({}): {}", error, err_desc)
|
||||
} else {
|
||||
format!("Error handling OIDC callback: {}", error)
|
||||
};
|
||||
error!("{}", msg);
|
||||
return Err(LoginError::IdpError(msg));
|
||||
}
|
||||
Ok(params
|
||||
.get("code")
|
||||
.ok_or(LoginError::MissingAuthCode)?
|
||||
.to_string())
|
||||
}
|
||||
|
||||
/// Exchange an OIDC Authorization code for an Identity Token
|
||||
async fn exchange_code(
|
||||
code: String,
|
||||
state: OidcClientState,
|
||||
ctx: &Context,
|
||||
) -> Result<CoreIdTokenClaims, LoginError> {
|
||||
let pkce_verifier = PkceCodeVerifier::new(state.pkce);
|
||||
debug!("Exchanging authorization code for access token");
|
||||
let http_client = reqwest::ClientBuilder::new()
|
||||
.redirect(reqwest::redirect::Policy::none())
|
||||
.build()
|
||||
.unwrap();
|
||||
let token_response = ctx
|
||||
.oidc
|
||||
.exchange_code(AuthorizationCode::new(code))?
|
||||
.set_pkce_verifier(pkce_verifier)
|
||||
.request_async(&http_client)
|
||||
.await
|
||||
.map_err(LoginError::TokenRequestError)?;
|
||||
debug!(
|
||||
"Received response token type {:?}",
|
||||
token_response.token_type()
|
||||
);
|
||||
debug!("Access token: {}", token_response.access_token().secret());
|
||||
trace!("Token response: {:?}", token_response);
|
||||
|
||||
let id_token = token_response
|
||||
.id_token()
|
||||
.ok_or(LoginError::MissingIdToken)?;
|
||||
let id_token_verifier = ctx.oidc.id_token_verifier();
|
||||
let claims = id_token
|
||||
.claims(&ctx.oidc.id_token_verifier(), &Nonce::new(state.nonce))?;
|
||||
if let Some(expected_access_token_hash) = claims.access_token_hash() {
|
||||
let actual_access_token_hash = AccessTokenHash::from_token(
|
||||
token_response.access_token(),
|
||||
id_token.signing_alg()?,
|
||||
id_token.signing_key(&id_token_verifier)?,
|
||||
)?;
|
||||
if actual_access_token_hash != *expected_access_token_hash {
|
||||
return Err(LoginError::InvalidAccessToken);
|
||||
}
|
||||
}
|
||||
Ok(claims.clone())
|
||||
}
|
||||
|
||||
/// Initiate OpenID Connect login flow
|
||||
#[rocket::get("/login")]
|
||||
pub async fn oidc_login(
|
||||
ctx: &State<Context>,
|
||||
cookies: &CookieJar<'_>,
|
||||
config: &RocketConfig,
|
||||
) -> Redirect {
|
||||
let (pkce_challenge, pkce_verifier) =
|
||||
PkceCodeChallenge::new_random_sha256();
|
||||
trace!("PKCE: {:?} {:?}", pkce_challenge, pkce_verifier);
|
||||
|
||||
let (auth_url, csrf_token, nonce) = ctx
|
||||
.oidc
|
||||
.authorize_url(
|
||||
CoreAuthenticationFlow::AuthorizationCode,
|
||||
CsrfToken::new_random,
|
||||
Nonce::new_random,
|
||||
)
|
||||
.set_pkce_challenge(pkce_challenge)
|
||||
.add_scope(Scope::new("openid".into()))
|
||||
.add_scope(Scope::new("profile".into()))
|
||||
.add_scope(Scope::new("email".into()))
|
||||
.add_scope(Scope::new("groups".into()))
|
||||
.url();
|
||||
trace!(
|
||||
"CSRF token: {}, nonce: {}",
|
||||
csrf_token.secret(),
|
||||
nonce.secret()
|
||||
);
|
||||
let state = OidcClientState {
|
||||
pkce: pkce_verifier.into_secret(),
|
||||
csrf: csrf_token.into_secret(),
|
||||
nonce: nonce.secret().to_string(),
|
||||
};
|
||||
match rocket::serde::json::to_string(&state) {
|
||||
Ok(s) => {
|
||||
cookies.add_private(
|
||||
Cookie::build(("oidc.state", s))
|
||||
.secure(config.profile != RocketConfig::DEBUG_PROFILE)
|
||||
.same_site(
|
||||
if config.profile == RocketConfig::DEBUG_PROFILE {
|
||||
SameSite::Lax
|
||||
} else {
|
||||
SameSite::Strict
|
||||
},
|
||||
)
|
||||
.http_only(true)
|
||||
.expires(None)
|
||||
.build(),
|
||||
);
|
||||
},
|
||||
Err(e) => {
|
||||
error!("Failed to serialize OIDC client state: {}", e);
|
||||
},
|
||||
};
|
||||
Redirect::to(auth_url.to_string())
|
||||
}
|
||||
|
||||
/// Handle OpenID Connect authorization callback
|
||||
#[rocket::get("/oidc-callback?<params..>")]
|
||||
pub async fn oidc_callback(
|
||||
params: HashMap<String, String>,
|
||||
cookies: &CookieJar<'_>,
|
||||
ctx: &State<Context>,
|
||||
config: &State<Config>,
|
||||
) -> Result<Redirect, LoginFailed> {
|
||||
trace!("{:?}", params);
|
||||
let state = cookies
|
||||
.get_private("oidc.state")
|
||||
.inspect(|c| cookies.remove(c.clone()))
|
||||
.ok_or(LoginError::MissingClientState)?;
|
||||
let state: OidcClientState = rocket::serde::json::from_str(state.value())
|
||||
.map_err(|e| {
|
||||
error!("Failed to parse OIDC client state: {}", e);
|
||||
LoginError::InvalidClientState
|
||||
})?;
|
||||
let code = get_auth_code(&state.csrf, ¶ms)?;
|
||||
trace!("Got authorization code: {}", code);
|
||||
let claims = exchange_code(code, state, ctx).await?;
|
||||
let user =
|
||||
UserClaims::new(claims.subject().as_str(), config.auth.login_ttl);
|
||||
let expires = OffsetDateTime::now_utc()
|
||||
+ Duration::seconds(config.auth.login_ttl.try_into().unwrap_or_else(
|
||||
|e| {
|
||||
warn!("Invalid login TTL value: {}", e);
|
||||
i64::MAX
|
||||
},
|
||||
));
|
||||
cookies.add(
|
||||
Cookie::build((
|
||||
"auth.token",
|
||||
ctx.make_jwt(&user).map_err(LoginError::from)?,
|
||||
))
|
||||
.secure(true)
|
||||
.http_only(true)
|
||||
.expires(expires)
|
||||
.build(),
|
||||
);
|
||||
Ok(Redirect::to("/"))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
@@ -139,8 +394,8 @@ mod test {
|
||||
};
|
||||
let jwt = make_token(&claims);
|
||||
let header = format!("Bearer {}", jwt);
|
||||
let data = extract_token(SECRET, &header).unwrap();
|
||||
assert_eq!(claims, data.claims);
|
||||
let claims = extract_token(SECRET, &header).unwrap();
|
||||
assert_eq!(claims, claims);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -233,4 +488,12 @@ mod test {
|
||||
let err = extract_token(SECRET, header).unwrap_err();
|
||||
assert_eq!(err, "Unsupported authorization scheme");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_userclaims_jwt_round_trip() {
|
||||
let claims = UserClaims::new("test1", 60);
|
||||
let jwt = claims.to_jwt(SECRET).unwrap();
|
||||
let claims1 = UserClaims::from_jwt(&jwt, SECRET).unwrap();
|
||||
assert_eq!(claims, claims1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,14 +14,29 @@ pub struct MeilisearchConfig {
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AuthConfig {
|
||||
pub jwt_secret: PathBuf,
|
||||
#[serde(default = "default_login_ttl")]
|
||||
pub login_ttl: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct OidcConfig {
|
||||
pub discovery_url: String,
|
||||
pub client_id: String,
|
||||
pub client_secret: Option<String>,
|
||||
pub callback_url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Config {
|
||||
pub auth: AuthConfig,
|
||||
pub oidc: OidcConfig,
|
||||
pub meilisearch: MeilisearchConfig,
|
||||
}
|
||||
|
||||
fn default_index() -> String {
|
||||
env!("CARGO_PKG_NAME").into()
|
||||
}
|
||||
|
||||
fn default_login_ttl() -> u64 {
|
||||
604800
|
||||
}
|
||||
|
||||
59
src/error.rs
59
src/error.rs
@@ -1,7 +1,66 @@
|
||||
use openidconnect::core::CoreErrorResponseType;
|
||||
use openidconnect::url::ParseError;
|
||||
use openidconnect::{
|
||||
ClaimsVerificationError, ConfigurationError, DiscoveryError,
|
||||
HttpClientError, RequestTokenError, SignatureVerificationError,
|
||||
SigningError, StandardErrorResponse,
|
||||
};
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum OidcError {
|
||||
#[error("Invalid URL {0}")]
|
||||
Url(#[from] ParseError),
|
||||
#[error("OIDC Discovery failed: {0}")]
|
||||
Discovery(#[from] DiscoveryError<HttpClientError<reqwest::Error>>),
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum LoginError {
|
||||
#[error("Missing state parameter")]
|
||||
MissingStateParam,
|
||||
#[error("Invalid state parameter")]
|
||||
InvalidCsrfState,
|
||||
#[error("Missing OIDC client state cookie")]
|
||||
MissingClientState,
|
||||
#[error("Invalid OIDC client state cookie")]
|
||||
InvalidClientState,
|
||||
#[error("{0}")]
|
||||
IdpError(String),
|
||||
#[error("Missing OAuth2 authorization code")]
|
||||
MissingAuthCode,
|
||||
#[error("OIDC configuration error: {0}")]
|
||||
ConfigurationError(#[from] ConfigurationError),
|
||||
#[error("Token request error: {0}")]
|
||||
TokenRequestError(
|
||||
RequestTokenError<
|
||||
HttpClientError<reqwest::Error>,
|
||||
StandardErrorResponse<CoreErrorResponseType>,
|
||||
>,
|
||||
),
|
||||
#[error("Server did not return an ID token")]
|
||||
MissingIdToken,
|
||||
#[error("Server returned an invalid ID token")]
|
||||
InvalidIdToken,
|
||||
#[error("Invalid token claims: {0}")]
|
||||
ClaimsVerificationError(#[from] ClaimsVerificationError),
|
||||
#[error("Signature verification error: {0}")]
|
||||
SignatureVerificationError(#[from] SignatureVerificationError),
|
||||
#[error("Token signature error: {0}")]
|
||||
SigningError(#[from] SigningError),
|
||||
#[error("Invalid access token")]
|
||||
InvalidAccessToken,
|
||||
#[error("JWT serialization error: {0}")]
|
||||
JwtError(#[from] jsonwebtoken::errors::Error),
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum InitError {
|
||||
#[error("Meilisearch error: {0}")]
|
||||
Meilisearch(#[from] crate::meilisearch::Error),
|
||||
|
||||
#[error("Failed to load JWT secret: {0}")]
|
||||
LoadSecret(std::io::Error),
|
||||
|
||||
#[error("Invalid OIDC configuration: {0}")]
|
||||
Oidc(#[from] OidcError),
|
||||
}
|
||||
|
||||
28
src/lib.rs
28
src/lib.rs
@@ -15,28 +15,39 @@ pub use error::InitError;
|
||||
pub struct Context {
|
||||
client: MeilisearchClient,
|
||||
jwt_secret: Vec<u8>,
|
||||
oidc: auth::OidcClient,
|
||||
}
|
||||
|
||||
impl Context {
|
||||
pub fn init(config: &Config) -> Result<Self, InitError> {
|
||||
pub async fn init(config: &Config) -> Result<Self, InitError> {
|
||||
let client = MeilisearchClient::try_from(config)?;
|
||||
let jwt_secret =
|
||||
auth::load_secret(config).map_err(InitError::LoadSecret)?;
|
||||
Ok(Self { client, jwt_secret })
|
||||
let oidc = auth::get_oidc_client(config).await?;
|
||||
Ok(Self {
|
||||
client,
|
||||
jwt_secret,
|
||||
oidc,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize the application context
|
||||
async fn init_context(rocket: Rocket<rocket::Build>) -> fairing::Result {
|
||||
let config: &Config = rocket.state().unwrap();
|
||||
let ctx = match Context::init(config) {
|
||||
let ctx = match Context::init(config).await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
eprintln!("Could not initialize application context: {}", e);
|
||||
return Err(rocket);
|
||||
},
|
||||
};
|
||||
Ok(rocket.manage(ctx))
|
||||
// Use the JWT secret as the Rocket secret key, for encrypted cookies
|
||||
let figment = rocket
|
||||
.figment()
|
||||
.clone()
|
||||
.join(("secret_key", &ctx.jwt_secret));
|
||||
Ok(rocket.configure(figment).manage(ctx))
|
||||
}
|
||||
|
||||
/// Set up Meilisearch
|
||||
@@ -59,7 +70,14 @@ async fn meilisearch_setup(rocket: Rocket<rocket::Build>) -> fairing::Result {
|
||||
|
||||
pub fn rocket() -> Rocket<rocket::Build> {
|
||||
rocket::build()
|
||||
.mount("/", rocket::routes![page::post_page])
|
||||
.mount(
|
||||
"/",
|
||||
rocket::routes![
|
||||
auth::oidc_callback,
|
||||
auth::oidc_login,
|
||||
page::post_page
|
||||
],
|
||||
)
|
||||
.attach(AdHoc::config::<Config>())
|
||||
.attach(AdHoc::try_on_ignite("Initialize context", init_context))
|
||||
.attach(AdHoc::try_on_ignite("Meilisearch Setup", meilisearch_setup))
|
||||
|
||||
84
tests/integration/auth.rs
Normal file
84
tests/integration/auth.rs
Normal file
@@ -0,0 +1,84 @@
|
||||
use rocket::http::Status;
|
||||
use rocket::local::asynchronous::Client;
|
||||
use rocket::uri;
|
||||
use scraper::{Html, Selector};
|
||||
use tracing::debug;
|
||||
|
||||
use seensite::auth::*;
|
||||
use seensite::Context;
|
||||
|
||||
#[rocket::async_test]
|
||||
async fn test_login() {
|
||||
super::setup();
|
||||
|
||||
let client = Client::tracked(seensite::rocket()).await.unwrap();
|
||||
let dex = reqwest::Client::builder()
|
||||
.redirect(reqwest::redirect::Policy::none())
|
||||
.cookie_store(true)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let ctx: &Context = client.rocket().state().unwrap();
|
||||
|
||||
// First, initiate the login process
|
||||
let req = client.get(uri![oidc_login]);
|
||||
let res = req.dispatch().await;
|
||||
assert_eq!(res.status(), Status::SeeOther);
|
||||
let mut location = res.headers().get_one("Location").unwrap().to_string();
|
||||
|
||||
// Next, follow the redirect URL provided by the login page.
|
||||
// This will redirect again.
|
||||
let res = loop {
|
||||
debug!("Redirect: {}", location);
|
||||
let res = dex.get(location).send().await.unwrap();
|
||||
if res.status() == reqwest::StatusCode::FOUND {
|
||||
let base_url = res.url().clone();
|
||||
location = base_url
|
||||
.join(res.headers().get("Location").unwrap().to_str().unwrap())
|
||||
.unwrap()
|
||||
.to_string();
|
||||
continue;
|
||||
}
|
||||
break res;
|
||||
};
|
||||
|
||||
// After all the redirects, we end up on the IdP login form.
|
||||
assert_eq!(res.status(), reqwest::StatusCode::OK);
|
||||
// Obtain the login form target
|
||||
let base_url = res.url().clone();
|
||||
let body = res.text().await;
|
||||
let doc = Html::parse_fragment(&body.unwrap());
|
||||
let sel = Selector::parse("form").unwrap();
|
||||
let form = doc.select(&sel).next().unwrap();
|
||||
let action = form.attr("action").unwrap();
|
||||
let url = base_url.join(action).unwrap();
|
||||
// Post the user credentials to the IdP login form
|
||||
let res = dex
|
||||
.post(url)
|
||||
.form(&[("login", "user@example.com"), ("password", "password")])
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), reqwest::StatusCode::SEE_OTHER);
|
||||
|
||||
// The result of the IdP login form submission is another redirect
|
||||
// to the OIDC callback of our application.
|
||||
let location = reqwest::Url::parse(
|
||||
res.headers().get("Location").unwrap().to_str().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
let callback =
|
||||
format!("{}?{}", location.path(), location.query().unwrap());
|
||||
debug!("Callback: {}", callback);
|
||||
|
||||
// Finally, make the callback request to finish the login process.
|
||||
let res = client.get(callback).dispatch().await;
|
||||
assert_eq!(res.status(), Status::SeeOther);
|
||||
let location = res.headers().get_one("Location").unwrap().to_string();
|
||||
assert_eq!(location, "/");
|
||||
let cookie = res.cookies().get("auth.token").unwrap();
|
||||
debug!("Cookie: {:?}", cookie);
|
||||
let claims = ctx.decode_jwt(cookie.value()).unwrap();
|
||||
debug!("Claims: {:?}", claims);
|
||||
assert!(!claims.sub.is_empty());
|
||||
}
|
||||
@@ -1 +1,22 @@
|
||||
mod auth;
|
||||
mod page;
|
||||
|
||||
use std::sync::LazyLock;
|
||||
|
||||
static SETUP: LazyLock<()> = LazyLock::new(|| {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(tracing_subscriber::EnvFilter::new(
|
||||
concat!(
|
||||
env!("CARGO_PKG_NAME"),
|
||||
"=trace,",
|
||||
"seensite=trace,",
|
||||
"debug",
|
||||
)
|
||||
))
|
||||
.with_test_writer()
|
||||
.init();
|
||||
});
|
||||
|
||||
fn setup() {
|
||||
LazyLock::force(&SETUP);
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ amet auctor risus lacinia. Duis feugiat lobortis orci quis sagittis.</p>
|
||||
|
||||
#[test]
|
||||
fn test_post_page() {
|
||||
super::setup();
|
||||
let client = Client::tracked(seensite::rocket()).unwrap();
|
||||
let ctx: &Context = client.rocket().state().unwrap();
|
||||
let claims = UserClaims::new("test1", 60);
|
||||
@@ -52,6 +53,7 @@ fn test_post_page() {
|
||||
|
||||
#[test]
|
||||
fn test_post_page_unauth() {
|
||||
super::setup();
|
||||
let client = Client::tracked(seensite::rocket()).unwrap();
|
||||
let data = Serializer::new(String::new())
|
||||
.append_pair("url", TEST_URL)
|
||||
|
||||
Reference in New Issue
Block a user