Compare commits
10 Commits
03fe19aa9b
...
dbd9165626
| Author | SHA1 | Date | |
|---|---|---|---|
| dbd9165626 | |||
| a50dca7fae | |||
| 3d2772cfc8 | |||
| 42502694f3 | |||
| e1da535dc2 | |||
| c227bec62d | |||
| 76cf57ebe0 | |||
| 720bb690ea | |||
| a1308507af | |||
| df560b18f2 |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,3 +1,5 @@
|
||||
/target
|
||||
/Rocket.toml
|
||||
/jwt.secret
|
||||
/meilisearch.token
|
||||
/oidc.secret
|
||||
|
||||
1179
Cargo.lock
generated
1179
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
10
Cargo.toml
10
Cargo.toml
@@ -6,11 +6,19 @@ edition = "2021"
|
||||
[dependencies]
|
||||
chrono = { version = "0.4.40", default-features = false, features = ["std", "clock", "serde"] }
|
||||
html5ever = "0.27.0"
|
||||
jsonwebtoken = { version = "9.3.1", default-features = false }
|
||||
markup5ever_rcdom = "0.3.0"
|
||||
meilisearch-sdk = "0.28.0"
|
||||
openidconnect = { version = "4.0.0", default-features = false, features = ["reqwest", "native-tls"] }
|
||||
rand = "0.9.0"
|
||||
rocket = { version = "0.5.1", features = ["json"] }
|
||||
reqwest = { version = "0.12.15", features = ["json", "native-tls"] }
|
||||
rocket = { version = "0.5.1", features = ["json", "secrets"] }
|
||||
serde = { version = "1.0.219", features = ["derive"] }
|
||||
thiserror = "2.0.12"
|
||||
tracing = "0.1.41"
|
||||
tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
|
||||
|
||||
[dev-dependencies]
|
||||
form_urlencoded = "1.2.1"
|
||||
reqwest = { version = "0.12.15", features = ["cookies"] }
|
||||
scraper = { version = "0.23.1", default-features = false }
|
||||
|
||||
9
Justfile
Normal file
9
Justfile
Normal file
@@ -0,0 +1,9 @@
|
||||
# vim: set et :
|
||||
|
||||
integration-tests:
|
||||
./run-tests.sh --test=integration
|
||||
|
||||
unit-tests:
|
||||
cargo test --lib
|
||||
|
||||
test: unit-tests integration-tests
|
||||
29
dex.yaml
Normal file
29
dex.yaml
Normal file
@@ -0,0 +1,29 @@
|
||||
issuer: http://127.0.0.1:5556/dex
|
||||
|
||||
storage:
|
||||
type: memory
|
||||
|
||||
web:
|
||||
http: 0.0.0.0:5556
|
||||
|
||||
staticClients:
|
||||
- id: example-app
|
||||
redirectURIs:
|
||||
- 'http://localhost:8000/oidc-callback'
|
||||
name: 'Example App'
|
||||
secret: example-app-secret
|
||||
|
||||
enablePasswordDB: true
|
||||
|
||||
staticPasswords:
|
||||
- email: "user@example.com"
|
||||
# bcrypt hash of "password"
|
||||
hash: "$2y$10$BvrNnoSifaDAcS95zT5CyegMKE90S9gpiMHcJj82hVnyRYwn2LNwS"
|
||||
username: "testuser"
|
||||
userID: "1234"
|
||||
|
||||
oauth2:
|
||||
skipApprovalScreen: true
|
||||
|
||||
logger:
|
||||
level: debug
|
||||
37
examples/make-token.rs
Normal file
37
examples/make-token.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
use std::io::Read;
|
||||
use std::time::SystemTime;
|
||||
|
||||
use jsonwebtoken::{encode, EncodingKey, Header};
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct UserClaims {
|
||||
aud: String,
|
||||
exp: u64,
|
||||
iat: u64,
|
||||
iss: String,
|
||||
nbf: u64,
|
||||
sub: String,
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let args: Vec<_> = std::env::args().collect();
|
||||
let mut secret = vec![];
|
||||
let mut f = std::fs::File::open(&args[1]).unwrap();
|
||||
f.read_to_end(&mut secret).unwrap();
|
||||
let k = EncodingKey::from_secret(&secret);
|
||||
let now = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
let claims = UserClaims {
|
||||
aud: env!("CARGO_PKG_NAME").into(),
|
||||
exp: now + 604800,
|
||||
iat: now,
|
||||
iss: env!("CARGO_PKG_NAME").into(),
|
||||
nbf: now - 60,
|
||||
sub: args[2].to_string(),
|
||||
};
|
||||
let jwt = encode(&Header::default(), &claims, &k).unwrap();
|
||||
println!("{}", jwt);
|
||||
}
|
||||
93
run-tests.sh
Executable file
93
run-tests.sh
Executable file
@@ -0,0 +1,93 @@
|
||||
#!/bin/sh
|
||||
# vim: set ts=4 sts=4 sw=4 et :
|
||||
|
||||
cleanup() {
|
||||
if [ -n "${cid}" ]; then
|
||||
echo 'Stopping Meilisearch ...' >&2
|
||||
podman stop "${cid}" >/dev/null
|
||||
fi
|
||||
unset cid
|
||||
if [ -f "${meilisearch_key}" ]; then
|
||||
rm "${meilisearch_key}"
|
||||
fi
|
||||
unset meilisearch_key
|
||||
if [ -n "${dex_cid}" ]; then
|
||||
echo 'Stopping Dex ...' >&2
|
||||
podman stop "${dex_cid}" >/dev/null
|
||||
fi
|
||||
unset cid
|
||||
if [ -f "${jwt_key}" ]; then
|
||||
rm "${jwt_key}"
|
||||
fi
|
||||
unset jwt_key
|
||||
if [ -f "${ROCKET_CONFIG}" ]; then
|
||||
rm "${ROCKET_CONFIG}"
|
||||
fi
|
||||
unset ROCKET_CONFIG
|
||||
}
|
||||
|
||||
trap cleanup INT TERM QUIT EXIT
|
||||
|
||||
echo 'Generating Meilisearch master key ...' >&2
|
||||
MEILI_MASTER_KEY=$(tr -cd 0-9A-Za-z_- < /dev/urandom | head -c 32)
|
||||
|
||||
echo 'Starting Meilisearch ...' >&2
|
||||
cid=$(
|
||||
export MEILI_MASTER_KEY
|
||||
podman run --rm -d -P \
|
||||
-e MEILI_NO_ANALYTICS=true \
|
||||
-e MEILI_MASTER_KEY \
|
||||
docker.io/getmeili/meilisearch:v1.13
|
||||
)
|
||||
|
||||
port=$(podman port ${cid} 7700/tcp | cut -d: -f2)
|
||||
printf 'Waiting for Meilisearch on port %d ...' "${port}" >&2
|
||||
until curl -fs -o /dev/null 127.0.0.1:${port}; do
|
||||
printf '.'
|
||||
sleep 0.25
|
||||
done
|
||||
echo ' ready' >&2
|
||||
|
||||
meilisearch_key=$(mktemp meilisearch.token.XXXXXX)
|
||||
curl -s 127.0.0.1:${port}/keys -H "Authorization: Bearer ${MEILI_MASTER_KEY}" \
|
||||
| jq -r '.results[] | select(.name == "Default Admin API Key") | .key' \
|
||||
> "${meilisearch_key}"
|
||||
|
||||
echo 'Starting Dex ...'
|
||||
dex_cid=$(
|
||||
podman run --rm -d \
|
||||
-p 5556:5556 \
|
||||
-v ./dex.yaml:/etc/dex/config.yaml:ro,z \
|
||||
ghcr.io/dexidp/dex \
|
||||
dex serve /etc/dex/config.yaml
|
||||
) || exit
|
||||
dex_port=5556
|
||||
printf 'Waiting for Dex on port %d ...' "${dex_port}" >&2
|
||||
until curl -fs -o /dev/null 127.0.0.1:${dex_port}/dex/healthz; do
|
||||
printf '.'
|
||||
sleep 0.25
|
||||
done
|
||||
echo ' ready' >&2
|
||||
|
||||
echo 'Generating JWT secret key ...' >&2
|
||||
jwt_key=$(mktemp jwt.secret.XXXXXX)
|
||||
tr -cd 0-9A-Za-z_- < /dev/urandom | head -c 32 > "${jwt_key}"
|
||||
|
||||
ROCKET_CONFIG=$(mktemp Rocket.toml.XXXXXX)
|
||||
cat > "${ROCKET_CONFIG}" <<EOF
|
||||
[test.meilisearch]
|
||||
url = "http://127.0.0.1:${port}"
|
||||
token = "${meilisearch_key}"
|
||||
|
||||
[test.auth]
|
||||
jwt_secret = "${jwt_key}"
|
||||
|
||||
[test.oidc]
|
||||
discovery_url = "http://127.0.0.1:${dex_port}/dex"
|
||||
client_id = "example-app"
|
||||
client_secret = "example-app-secret"
|
||||
callback_url = "http://localhost:8000/oidc-callback"
|
||||
EOF
|
||||
export ROCKET_PROFILE=test
|
||||
export ROCKET_CONFIG
|
||||
cargo test "$@"
|
||||
541
src/auth.rs
Normal file
541
src/auth.rs
Normal file
@@ -0,0 +1,541 @@
|
||||
use std::path::Path;
|
||||
use std::collections::HashMap;
|
||||
use std::io::Read;
|
||||
use std::time::SystemTime;
|
||||
|
||||
use jsonwebtoken::{
|
||||
decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation,
|
||||
};
|
||||
use openidconnect::core::{
|
||||
CoreAuthenticationFlow, CoreClient, CoreIdTokenClaims,
|
||||
CoreProviderMetadata,
|
||||
};
|
||||
use openidconnect::{
|
||||
AccessTokenHash, AuthorizationCode, ClientId, ClientSecret, CsrfToken,
|
||||
EndpointMaybeSet, EndpointNotSet, EndpointSet, IssuerUrl, Nonce,
|
||||
OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl,
|
||||
Scope, TokenResponse,
|
||||
};
|
||||
use rocket::config::Config as RocketConfig;
|
||||
use rocket::http::{Cookie, CookieJar, SameSite, Status};
|
||||
use rocket::request::{FromRequest, Outcome, Request};
|
||||
use rocket::response::Redirect;
|
||||
use rocket::time::{Duration, OffsetDateTime};
|
||||
use rocket::State;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{debug, error, trace, warn};
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::error::{LoginError, OidcError};
|
||||
use crate::Context;
|
||||
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
||||
struct UserClaims {
|
||||
aud: String,
|
||||
exp: u64,
|
||||
iat: u64,
|
||||
iss: String,
|
||||
nbf: u64,
|
||||
sub: String,
|
||||
}
|
||||
|
||||
/// Represents an authenticated user
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct User {
|
||||
id: String,
|
||||
}
|
||||
|
||||
impl User {
|
||||
pub fn new(id: &'static str) -> Self {
|
||||
Self {
|
||||
id: id.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the user ID
|
||||
///
|
||||
/// The user ID is an arbitrary string assigned by the identity
|
||||
/// provider. It does NOT necessarily represent a username, email
|
||||
/// address, or any other name; it is simply a unique identifier.
|
||||
pub fn id(&self) -> &str {
|
||||
&self.id
|
||||
}
|
||||
}
|
||||
|
||||
#[rocket::async_trait]
|
||||
impl<'r> FromRequest<'r> for User {
|
||||
type Error = &'static str;
|
||||
|
||||
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||
let secret = &req.rocket().state::<Context>().unwrap().jwt_secret;
|
||||
match req.cookies().get("auth.token") {
|
||||
Some(c) => match UserClaims::from_jwt(c.value(), secret) {
|
||||
Ok(d) => Outcome::Success(d.into()),
|
||||
Err(e) => {
|
||||
debug!("Invalid auth token: {}", e);
|
||||
Outcome::Error((Status::Unauthorized, "Unauthorized"))
|
||||
},
|
||||
},
|
||||
None => match req.headers().get_one("Authorization") {
|
||||
Some(s) => match extract_token(secret, s) {
|
||||
Ok(d) => Outcome::Success(d.into()),
|
||||
Err(e) => Outcome::Error((Status::Unauthorized, e)),
|
||||
},
|
||||
None => Outcome::Error((Status::Unauthorized, "Unauthorized")),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<UserClaims> for User {
|
||||
fn from(claims: UserClaims) -> Self {
|
||||
Self { id: claims.sub }
|
||||
}
|
||||
}
|
||||
|
||||
impl UserClaims {
|
||||
fn to_jwt(
|
||||
&self,
|
||||
secret: &[u8],
|
||||
) -> Result<String, jsonwebtoken::errors::Error> {
|
||||
let k = EncodingKey::from_secret(secret);
|
||||
encode(&Header::default(), self, &k)
|
||||
}
|
||||
|
||||
fn from_jwt(
|
||||
token: &str,
|
||||
secret: &[u8],
|
||||
) -> Result<Self, jsonwebtoken::errors::Error> {
|
||||
let mut v = Validation::new(Algorithm::HS256);
|
||||
v.validate_nbf = true;
|
||||
v.set_issuer(&[env!("CARGO_PKG_NAME")]);
|
||||
v.set_audience(&[env!("CARGO_PKG_NAME")]);
|
||||
let k = DecodingKey::from_secret(secret);
|
||||
Ok(decode(token, &k, &v)?.claims)
|
||||
}
|
||||
}
|
||||
|
||||
impl Context {
|
||||
/// Create a signed JWT for a user
|
||||
pub fn make_jwt(
|
||||
&self,
|
||||
user: &User,
|
||||
ttl: u64,
|
||||
) -> Result<String, jsonwebtoken::errors::Error> {
|
||||
let now = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
|
||||
let claims = UserClaims {
|
||||
aud: env!("CARGO_PKG_NAME").into(),
|
||||
exp: now + ttl,
|
||||
iat: now,
|
||||
iss: env!("CARGO_PKG_NAME").into(),
|
||||
nbf: now - 60,
|
||||
sub: user.id().into(),
|
||||
};
|
||||
claims.to_jwt(&self.jwt_secret)
|
||||
}
|
||||
|
||||
/// Decode the [`User`]
|
||||
pub fn decode_jwt(
|
||||
&self,
|
||||
token: &str,
|
||||
) -> Result<User, jsonwebtoken::errors::Error> {
|
||||
Ok(UserClaims::from_jwt(token, &self.jwt_secret)?.into())
|
||||
}
|
||||
}
|
||||
|
||||
pub type OidcClient = CoreClient<
|
||||
EndpointSet,
|
||||
EndpointNotSet,
|
||||
EndpointNotSet,
|
||||
EndpointNotSet,
|
||||
EndpointMaybeSet,
|
||||
EndpointMaybeSet,
|
||||
>;
|
||||
|
||||
/// Represents the OIDC login state, which is stored in a cookie
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct OidcClientState {
|
||||
pkce: String,
|
||||
csrf: String,
|
||||
nonce: String,
|
||||
}
|
||||
|
||||
/// Error response returned to the client when OIDC login fails
|
||||
#[derive(Debug, rocket::Responder)]
|
||||
#[response(status = 401)]
|
||||
pub struct LoginFailed(String);
|
||||
|
||||
impl From<LoginError> for LoginFailed {
|
||||
fn from(e: LoginError) -> Self {
|
||||
error!("Login failed: {}", e);
|
||||
Self(e.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an openidconnect Client instance from the given Config
|
||||
pub async fn get_oidc_client(
|
||||
discovery_url: String,
|
||||
client_id: String,
|
||||
client_secret: Option<String>,
|
||||
callback_url: String,
|
||||
http_client: &reqwest::Client,
|
||||
) -> Result<OidcClient, OidcError> {
|
||||
let provider_metadata = CoreProviderMetadata::discover_async(
|
||||
IssuerUrl::new(discovery_url)?,
|
||||
http_client,
|
||||
)
|
||||
.await?;
|
||||
Ok(CoreClient::from_provider_metadata(
|
||||
provider_metadata,
|
||||
ClientId::new(client_id),
|
||||
client_secret.map(ClientSecret::new),
|
||||
)
|
||||
.set_redirect_uri(RedirectUrl::new(callback_url)?))
|
||||
}
|
||||
|
||||
/// Load a secret from the specified path
|
||||
pub fn load_secret<P: AsRef<Path>>(path: P) -> Result<Vec<u8>, std::io::Error> {
|
||||
let mut secret = vec![];
|
||||
let mut f = std::fs::File::open(path)?;
|
||||
f.read_to_end(&mut secret)?;
|
||||
Ok(secret)
|
||||
}
|
||||
|
||||
/// Extract a JWT from the Authorization HTTP request header
|
||||
fn extract_token(
|
||||
secret: &[u8],
|
||||
header: &str,
|
||||
) -> Result<UserClaims, &'static str> {
|
||||
let mut iter = header.split_ascii_whitespace();
|
||||
let scheme = iter.next();
|
||||
let token = iter.next();
|
||||
if token.is_some() && Some("Bearer") != scheme {
|
||||
return Err("Unsupported authorization scheme");
|
||||
}
|
||||
if let Some(token) = token {
|
||||
Ok(UserClaims::from_jwt(token, secret).map_err(|e| {
|
||||
error!("Failed to decode auth token: {}", e);
|
||||
"Invalid token"
|
||||
})?)
|
||||
} else {
|
||||
Err("Invalid Authorization header")
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the authorization code from OIDC callback parameters
|
||||
///
|
||||
/// If the callback indicates an authorization failure, or the OAuth2
|
||||
/// state is invalid, a [`LoginError`] is returned.
|
||||
fn get_auth_code(
|
||||
csrf_token: &str,
|
||||
params: &HashMap<String, String>,
|
||||
) -> Result<String, LoginError> {
|
||||
let state = params.get("state").ok_or(LoginError::MissingStateParam)?;
|
||||
if state != csrf_token {
|
||||
return Err(LoginError::InvalidCsrfState);
|
||||
}
|
||||
if let Some(error) = params.get("error") {
|
||||
let msg = if let Some(err_desc) = params.get("error_description") {
|
||||
format!("Error handling OIDC callback ({}): {}", error, err_desc)
|
||||
} else {
|
||||
format!("Error handling OIDC callback: {}", error)
|
||||
};
|
||||
error!("{}", msg);
|
||||
return Err(LoginError::IdpError(msg));
|
||||
}
|
||||
Ok(params
|
||||
.get("code")
|
||||
.ok_or(LoginError::MissingAuthCode)?
|
||||
.to_string())
|
||||
}
|
||||
|
||||
/// Exchange an OIDC Authorization code for an Identity Token
|
||||
async fn exchange_code(
|
||||
code: String,
|
||||
state: OidcClientState,
|
||||
ctx: &Context,
|
||||
) -> Result<CoreIdTokenClaims, LoginError> {
|
||||
let pkce_verifier = PkceCodeVerifier::new(state.pkce);
|
||||
debug!("Exchanging authorization code for access token");
|
||||
let http_client = &ctx.oidc_http_client;
|
||||
let oidc = ctx.oidc().await?;
|
||||
let token_response = oidc
|
||||
.exchange_code(AuthorizationCode::new(code))?
|
||||
.set_pkce_verifier(pkce_verifier)
|
||||
.request_async(http_client)
|
||||
.await
|
||||
.map_err(LoginError::TokenRequestError)?;
|
||||
debug!(
|
||||
"Received response token type {:?}",
|
||||
token_response.token_type()
|
||||
);
|
||||
debug!("Access token: {}", token_response.access_token().secret());
|
||||
trace!("Token response: {:?}", token_response);
|
||||
|
||||
let id_token = token_response
|
||||
.id_token()
|
||||
.ok_or(LoginError::MissingIdToken)?;
|
||||
let id_token_verifier = oidc.id_token_verifier();
|
||||
let claims = id_token
|
||||
.claims(&oidc.id_token_verifier(), &Nonce::new(state.nonce))?;
|
||||
if let Some(expected_access_token_hash) = claims.access_token_hash() {
|
||||
let actual_access_token_hash = AccessTokenHash::from_token(
|
||||
token_response.access_token(),
|
||||
id_token.signing_alg()?,
|
||||
id_token.signing_key(&id_token_verifier)?,
|
||||
)?;
|
||||
if actual_access_token_hash != *expected_access_token_hash {
|
||||
return Err(LoginError::InvalidAccessToken);
|
||||
}
|
||||
}
|
||||
Ok(claims.clone())
|
||||
}
|
||||
|
||||
/// Initiate OpenID Connect login flow
|
||||
#[rocket::get("/login")]
|
||||
pub async fn oidc_login(
|
||||
ctx: &State<Context>,
|
||||
cookies: &CookieJar<'_>,
|
||||
config: &RocketConfig,
|
||||
) -> Result<Redirect, LoginFailed> {
|
||||
let (pkce_challenge, pkce_verifier) =
|
||||
PkceCodeChallenge::new_random_sha256();
|
||||
trace!("PKCE: {:?} {:?}", pkce_challenge, pkce_verifier);
|
||||
|
||||
let oidc = ctx.oidc().await.map_err(LoginError::from)?;
|
||||
let (auth_url, csrf_token, nonce) = oidc
|
||||
.authorize_url(
|
||||
CoreAuthenticationFlow::AuthorizationCode,
|
||||
CsrfToken::new_random,
|
||||
Nonce::new_random,
|
||||
)
|
||||
.set_pkce_challenge(pkce_challenge)
|
||||
.add_scope(Scope::new("openid".into()))
|
||||
.add_scope(Scope::new("profile".into()))
|
||||
.add_scope(Scope::new("email".into()))
|
||||
.add_scope(Scope::new("groups".into()))
|
||||
.url();
|
||||
trace!(
|
||||
"CSRF token: {}, nonce: {}",
|
||||
csrf_token.secret(),
|
||||
nonce.secret()
|
||||
);
|
||||
let state = OidcClientState {
|
||||
pkce: pkce_verifier.into_secret(),
|
||||
csrf: csrf_token.into_secret(),
|
||||
nonce: nonce.secret().to_string(),
|
||||
};
|
||||
match rocket::serde::json::to_string(&state) {
|
||||
Ok(s) => {
|
||||
cookies.add_private(
|
||||
Cookie::build(("oidc.state", s))
|
||||
.secure(config.profile != RocketConfig::DEBUG_PROFILE)
|
||||
.same_site(
|
||||
if config.profile == RocketConfig::DEBUG_PROFILE {
|
||||
SameSite::Lax
|
||||
} else {
|
||||
SameSite::Strict
|
||||
},
|
||||
)
|
||||
.http_only(true)
|
||||
.expires(None)
|
||||
.build(),
|
||||
);
|
||||
},
|
||||
Err(e) => {
|
||||
error!("Failed to serialize OIDC client state: {}", e);
|
||||
},
|
||||
};
|
||||
Ok(Redirect::to(auth_url.to_string()))
|
||||
}
|
||||
|
||||
/// Handle OpenID Connect authorization callback
|
||||
#[rocket::get("/oidc-callback?<params..>")]
|
||||
pub async fn oidc_callback(
|
||||
params: HashMap<String, String>,
|
||||
cookies: &CookieJar<'_>,
|
||||
ctx: &State<Context>,
|
||||
config: &State<Config>,
|
||||
) -> Result<Redirect, LoginFailed> {
|
||||
trace!("{:?}", params);
|
||||
let state = cookies
|
||||
.get_private("oidc.state")
|
||||
.inspect(|c| cookies.remove(c.clone()))
|
||||
.ok_or(LoginError::MissingClientState)?;
|
||||
let state: OidcClientState = rocket::serde::json::from_str(state.value())
|
||||
.map_err(|e| {
|
||||
error!("Failed to parse OIDC client state: {}", e);
|
||||
LoginError::InvalidClientState
|
||||
})?;
|
||||
let code = get_auth_code(&state.csrf, ¶ms)?;
|
||||
trace!("Got authorization code: {}", code);
|
||||
let claims = exchange_code(code, state, ctx).await?;
|
||||
let user = User {
|
||||
id: claims.subject().to_string(),
|
||||
};
|
||||
let expires = OffsetDateTime::now_utc()
|
||||
+ Duration::seconds(config.auth.login_ttl.try_into().unwrap_or_else(
|
||||
|e| {
|
||||
warn!("Invalid login TTL value: {}", e);
|
||||
i64::MAX
|
||||
},
|
||||
));
|
||||
cookies.add(
|
||||
Cookie::build((
|
||||
"auth.token",
|
||||
ctx.make_jwt(&user, config.auth.login_ttl)
|
||||
.map_err(LoginError::from)?,
|
||||
))
|
||||
.secure(true)
|
||||
.http_only(true)
|
||||
.expires(expires)
|
||||
.build(),
|
||||
);
|
||||
Ok(Redirect::to("/"))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
static SECRET: &[u8; 32] = b"6gmrLQ4PGjeUpT3Xs48thx9Cu6XE5pgD";
|
||||
|
||||
fn now() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
fn make_token(claims: &UserClaims) -> String {
|
||||
claims.to_jwt(SECRET).unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_token() {
|
||||
let now = now();
|
||||
let claims = UserClaims {
|
||||
aud: env!("CARGO_PKG_NAME").into(),
|
||||
exp: now + 60,
|
||||
iat: now,
|
||||
iss: env!("CARGO_PKG_NAME").into(),
|
||||
nbf: now - 60,
|
||||
sub: "test1".into(),
|
||||
};
|
||||
let jwt = make_token(&claims);
|
||||
let header = format!("Bearer {}", jwt);
|
||||
let claims = extract_token(SECRET, &header).unwrap();
|
||||
assert_eq!(claims, claims);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_token_expired() {
|
||||
let now = now() - 600;
|
||||
let claims = UserClaims {
|
||||
aud: env!("CARGO_PKG_NAME").into(),
|
||||
exp: now + 60,
|
||||
iat: now,
|
||||
iss: env!("CARGO_PKG_NAME").into(),
|
||||
nbf: now - 60,
|
||||
sub: "test1".into(),
|
||||
};
|
||||
let jwt = make_token(&claims);
|
||||
let header = format!("Bearer {}", jwt);
|
||||
let err = extract_token(SECRET, &header).unwrap_err();
|
||||
assert_eq!(err, "Invalid token");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_token_future() {
|
||||
let now = now() + 600;
|
||||
let claims = UserClaims {
|
||||
aud: env!("CARGO_PKG_NAME").into(),
|
||||
exp: now + 60,
|
||||
iat: now,
|
||||
iss: env!("CARGO_PKG_NAME").into(),
|
||||
nbf: now - 60,
|
||||
sub: "test1".into(),
|
||||
};
|
||||
let jwt = make_token(&claims);
|
||||
let header = format!("Bearer {}", jwt);
|
||||
let err = extract_token(SECRET, &header).unwrap_err();
|
||||
assert_eq!(err, "Invalid token");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_token_bad_issuer() {
|
||||
let now = now();
|
||||
let claims = UserClaims {
|
||||
aud: env!("CARGO_PKG_NAME").into(),
|
||||
exp: now + 60,
|
||||
iat: now,
|
||||
iss: "mallory".into(),
|
||||
nbf: now - 60,
|
||||
sub: "test1".into(),
|
||||
};
|
||||
let jwt = make_token(&claims);
|
||||
let header = format!("Bearer {}", jwt);
|
||||
let err = extract_token(SECRET, &header).unwrap_err();
|
||||
assert_eq!(err, "Invalid token");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_token_bad_aud() {
|
||||
let now = now();
|
||||
let claims = UserClaims {
|
||||
aud: "mallory".into(),
|
||||
exp: now + 60,
|
||||
iat: now,
|
||||
iss: env!("CARGO_PKG_NAME").into(),
|
||||
nbf: now - 60,
|
||||
sub: "test1".into(),
|
||||
};
|
||||
let jwt = make_token(&claims);
|
||||
let header = format!("Bearer {}", jwt);
|
||||
let err = extract_token(SECRET, &header).unwrap_err();
|
||||
assert_eq!(err, "Invalid token");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_malformed_header() {
|
||||
let now = now();
|
||||
let claims = UserClaims {
|
||||
aud: "mallory".into(),
|
||||
exp: now + 60,
|
||||
iat: now,
|
||||
iss: env!("CARGO_PKG_NAME").into(),
|
||||
nbf: now - 60,
|
||||
sub: "test1".into(),
|
||||
};
|
||||
let jwt = make_token(&claims);
|
||||
let err = extract_token(SECRET, &jwt).unwrap_err();
|
||||
assert_eq!(err, "Invalid Authorization header");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unsupported_auth_scheme() {
|
||||
let header = "Basic dXNlcm5hbWU6cGFzc3dvcmQ=";
|
||||
let err = extract_token(SECRET, header).unwrap_err();
|
||||
assert_eq!(err, "Unsupported authorization scheme");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_userclaims_jwt_round_trip() {
|
||||
let now = now();
|
||||
let claims = UserClaims {
|
||||
aud: env!("CARGO_PKG_NAME").into(),
|
||||
exp: now + 60,
|
||||
iat: now,
|
||||
iss: env!("CARGO_PKG_NAME").into(),
|
||||
nbf: now - 60,
|
||||
sub: "test1".into(),
|
||||
};
|
||||
let jwt = claims.to_jwt(SECRET).unwrap();
|
||||
let claims1 = UserClaims::from_jwt(&jwt, SECRET).unwrap();
|
||||
assert_eq!(claims, claims1);
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -9,11 +11,32 @@ pub struct MeilisearchConfig {
|
||||
pub index: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AuthConfig {
|
||||
pub jwt_secret: PathBuf,
|
||||
#[serde(default = "default_login_ttl")]
|
||||
pub login_ttl: u64,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct OidcConfig {
|
||||
pub discovery_url: String,
|
||||
pub client_id: String,
|
||||
pub client_secret: Option<String>,
|
||||
pub callback_url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Config {
|
||||
pub auth: AuthConfig,
|
||||
pub oidc: OidcConfig,
|
||||
pub meilisearch: MeilisearchConfig,
|
||||
}
|
||||
|
||||
fn default_index() -> String {
|
||||
env!("CARGO_PKG_NAME").into()
|
||||
}
|
||||
|
||||
fn default_login_ttl() -> u64 {
|
||||
604800
|
||||
}
|
||||
|
||||
68
src/error.rs
Normal file
68
src/error.rs
Normal file
@@ -0,0 +1,68 @@
|
||||
use openidconnect::core::CoreErrorResponseType;
|
||||
use openidconnect::url::ParseError;
|
||||
use openidconnect::{
|
||||
ClaimsVerificationError, ConfigurationError, DiscoveryError,
|
||||
HttpClientError, RequestTokenError, SignatureVerificationError,
|
||||
SigningError, StandardErrorResponse,
|
||||
};
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum OidcError {
|
||||
#[error("Invalid URL {0}")]
|
||||
Url(#[from] ParseError),
|
||||
#[error("OIDC Discovery failed: {0}")]
|
||||
Discovery(#[from] DiscoveryError<HttpClientError<reqwest::Error>>),
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum LoginError {
|
||||
#[error("Missing state parameter")]
|
||||
MissingStateParam,
|
||||
#[error("Invalid state parameter")]
|
||||
InvalidCsrfState,
|
||||
#[error("Missing OIDC client state cookie")]
|
||||
MissingClientState,
|
||||
#[error("Invalid OIDC client state cookie")]
|
||||
InvalidClientState,
|
||||
#[error("{0}")]
|
||||
IdpError(String),
|
||||
#[error("Missing OAuth2 authorization code")]
|
||||
MissingAuthCode,
|
||||
#[error("OIDC configuration error: {0}")]
|
||||
ConfigurationError(#[from] ConfigurationError),
|
||||
#[error("Token request error: {0}")]
|
||||
TokenRequestError(
|
||||
RequestTokenError<
|
||||
HttpClientError<reqwest::Error>,
|
||||
StandardErrorResponse<CoreErrorResponseType>,
|
||||
>,
|
||||
),
|
||||
#[error("Server did not return an ID token")]
|
||||
MissingIdToken,
|
||||
#[error("Server returned an invalid ID token")]
|
||||
InvalidIdToken,
|
||||
#[error("Invalid token claims: {0}")]
|
||||
ClaimsVerificationError(#[from] ClaimsVerificationError),
|
||||
#[error("Signature verification error: {0}")]
|
||||
SignatureVerificationError(#[from] SignatureVerificationError),
|
||||
#[error("Token signature error: {0}")]
|
||||
SigningError(#[from] SigningError),
|
||||
#[error("Invalid access token")]
|
||||
InvalidAccessToken,
|
||||
#[error("JWT serialization error: {0}")]
|
||||
JwtError(#[from] jsonwebtoken::errors::Error),
|
||||
#[error("Invalid OIDC configuration: {0}")]
|
||||
Oidc(#[from] OidcError),
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum InitError {
|
||||
#[error("Meilisearch error: {0}")]
|
||||
Meilisearch(#[from] crate::meilisearch::Error),
|
||||
#[error("Failed to load JWT secret: {0}")]
|
||||
LoadJwtSecret(std::io::Error),
|
||||
#[error("Failed to load OIDC client secret: {0}")]
|
||||
LoadOidcSecret(std::io::Error),
|
||||
#[error("Failed to initialize HTTP client: {0}")]
|
||||
ReqwestError(#[from] reqwest::Error)
|
||||
}
|
||||
109
src/lib.rs
Normal file
109
src/lib.rs
Normal file
@@ -0,0 +1,109 @@
|
||||
pub mod auth;
|
||||
pub mod config;
|
||||
mod error;
|
||||
pub mod meilisearch;
|
||||
pub mod page;
|
||||
|
||||
use meilisearch_sdk::client::Client as MeilisearchClient;
|
||||
use rocket::fairing::{self, AdHoc};
|
||||
use rocket::Rocket;
|
||||
use tracing::error;
|
||||
|
||||
use config::Config;
|
||||
pub use error::InitError;
|
||||
|
||||
pub struct Context {
|
||||
client: MeilisearchClient,
|
||||
jwt_secret: Vec<u8>,
|
||||
oidc: config::OidcConfig,
|
||||
oidc_client_secret: Option<String>,
|
||||
oidc_http_client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl Context {
|
||||
pub async fn init(config: &Config) -> Result<Self, InitError> {
|
||||
let client = MeilisearchClient::try_from(config)?;
|
||||
let jwt_secret = auth::load_secret(&config.auth.jwt_secret)
|
||||
.map_err(InitError::LoadJwtSecret)?;
|
||||
let oidc = config.oidc.clone();
|
||||
let oidc_client_secret = match &config.oidc.client_secret {
|
||||
Some(p) => match auth::load_secret(p) {
|
||||
Ok(s) => Some(String::from_utf8_lossy(&s).trim().to_string()),
|
||||
Err(e) => return Err(InitError::LoadOidcSecret(e)),
|
||||
},
|
||||
None => None,
|
||||
};
|
||||
let oidc_http_client = reqwest::ClientBuilder::new()
|
||||
.redirect(reqwest::redirect::Policy::none())
|
||||
.build()?;
|
||||
Ok(Self {
|
||||
client,
|
||||
jwt_secret,
|
||||
oidc,
|
||||
oidc_client_secret,
|
||||
oidc_http_client,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn oidc(&self) -> Result<auth::OidcClient, error::OidcError> {
|
||||
auth::get_oidc_client(
|
||||
self.oidc.discovery_url.clone(),
|
||||
self.oidc.client_id.clone(),
|
||||
self.oidc_client_secret.clone(),
|
||||
self.oidc.callback_url.clone(),
|
||||
&self.oidc_http_client,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize the application context
|
||||
async fn init_context(rocket: Rocket<rocket::Build>) -> fairing::Result {
|
||||
let config: &Config = rocket.state().unwrap();
|
||||
let ctx = match Context::init(config).await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
eprintln!("Could not initialize application context: {}", e);
|
||||
return Err(rocket);
|
||||
},
|
||||
};
|
||||
// Use the JWT secret as the Rocket secret key, for encrypted cookies
|
||||
let figment = rocket
|
||||
.figment()
|
||||
.clone()
|
||||
.join(("secret_key", &ctx.jwt_secret));
|
||||
Ok(rocket.configure(figment).manage(ctx))
|
||||
}
|
||||
|
||||
/// Set up Meilisearch
|
||||
async fn meilisearch_setup(rocket: Rocket<rocket::Build>) -> fairing::Result {
|
||||
let config: &Config = rocket.state().unwrap();
|
||||
let ctx: &Context = match rocket.state() {
|
||||
Some(c) => c,
|
||||
None => return Err(rocket),
|
||||
};
|
||||
let client = &ctx.client;
|
||||
if let Err(e) =
|
||||
meilisearch::ensure_index(client, &config.meilisearch.index).await
|
||||
{
|
||||
error!("Failed to create Meilisearch index: {}", e);
|
||||
Err(rocket)
|
||||
} else {
|
||||
Ok(rocket)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rocket() -> Rocket<rocket::Build> {
|
||||
rocket::build()
|
||||
.mount(
|
||||
"/",
|
||||
rocket::routes![
|
||||
auth::oidc_callback,
|
||||
auth::oidc_login,
|
||||
page::post_page
|
||||
],
|
||||
)
|
||||
.attach(AdHoc::config::<Config>())
|
||||
.attach(AdHoc::try_on_ignite("Initialize context", init_context))
|
||||
.attach(AdHoc::try_on_ignite("Meilisearch Setup", meilisearch_setup))
|
||||
}
|
||||
83
src/main.rs
83
src/main.rs
@@ -1,88 +1,9 @@
|
||||
mod config;
|
||||
mod meilisearch;
|
||||
mod page;
|
||||
|
||||
use meilisearch_sdk::client::Client as MeilisearchClient;
|
||||
use rocket::fairing::{self, AdHoc};
|
||||
use rocket::form::Form;
|
||||
use rocket::serde::json::Json;
|
||||
use rocket::{Rocket, State};
|
||||
use tracing::error;
|
||||
|
||||
use config::Config;
|
||||
|
||||
struct Context {
|
||||
#[allow(dead_code)]
|
||||
config: Config,
|
||||
client: MeilisearchClient,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
enum InitError {
|
||||
#[error("Meilisearch error: {0}")]
|
||||
Meilisearch(#[from] meilisearch::Error),
|
||||
}
|
||||
|
||||
impl Context {
|
||||
pub fn init(config: Config) -> Result<Self, InitError> {
|
||||
let client = MeilisearchClient::try_from(&config)?;
|
||||
Ok(Self { config, client })
|
||||
}
|
||||
}
|
||||
|
||||
/// Save page form
|
||||
#[derive(rocket::FromForm)]
|
||||
pub struct SavePageForm {
|
||||
/// Page URL
|
||||
url: String,
|
||||
/// Page content (SingleFile HTML)
|
||||
data: String,
|
||||
}
|
||||
|
||||
/// Save a visited page in SingleFile format
|
||||
#[rocket::post("/save", data = "<form>")]
|
||||
async fn save_page(
|
||||
form: Form<SavePageForm>,
|
||||
ctx: &State<Context>,
|
||||
) -> Result<Json<page::Page>, String> {
|
||||
match page::save_page(&form.url, &form.data, ctx).await {
|
||||
Ok(p) => Ok(Json(p)),
|
||||
Err(e) => {
|
||||
error!("Failed to save page: {}", e);
|
||||
Err(e.to_string())
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Set up Meilisearch
|
||||
async fn meilisearch_setup(rocket: Rocket<rocket::Build>) -> fairing::Result {
|
||||
let ctx: &Context = rocket.state().unwrap();
|
||||
let client = &ctx.client;
|
||||
let config = &ctx.config;
|
||||
if let Err(e) =
|
||||
meilisearch::ensure_index(client, &config.meilisearch.index).await
|
||||
{
|
||||
error!("Failed to create Meilisearch index: {}", e);
|
||||
Err(rocket)
|
||||
} else {
|
||||
Ok(rocket)
|
||||
}
|
||||
}
|
||||
|
||||
#[rocket::launch]
|
||||
async fn rocket() -> _ {
|
||||
fn rocket() -> _ {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
|
||||
.with_writer(std::io::stderr)
|
||||
.init();
|
||||
|
||||
let rocket = rocket::build();
|
||||
|
||||
let config: Config = rocket.figment().extract().unwrap();
|
||||
let ctx = Context::init(config).unwrap();
|
||||
|
||||
rocket
|
||||
.manage(ctx)
|
||||
.mount("/", rocket::routes![save_page])
|
||||
.attach(AdHoc::try_on_ignite("Meilisearch Setup", meilisearch_setup))
|
||||
seensite::rocket()
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ impl TryFrom<&Config> for Client {
|
||||
Some(t) => Some(std::fs::read_to_string(t).map_err(Error::Token)?),
|
||||
None => None,
|
||||
};
|
||||
Ok(Client::new(&config.meilisearch.url, token)?)
|
||||
Ok(Client::new(&config.meilisearch.url, token.as_deref().map(str::trim))?)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
45
src/page.rs
45
src/page.rs
@@ -4,8 +4,14 @@ use html5ever::tendril::TendrilSink;
|
||||
use markup5ever_rcdom::{Handle, NodeData, RcDom};
|
||||
use meilisearch_sdk::errors::Error;
|
||||
use rand::Rng;
|
||||
use rocket::form::Form;
|
||||
use rocket::serde::json::Json;
|
||||
use rocket::State;
|
||||
use serde::Serialize;
|
||||
use tracing::{debug, error, event, span, Level};
|
||||
|
||||
use crate::auth::User;
|
||||
use crate::config::Config;
|
||||
use crate::Context;
|
||||
|
||||
static ID_CHARSET: &[u8] = b"0123456789abcdefghijklmnopqrstuvwxyz";
|
||||
@@ -15,6 +21,8 @@ static ID_CHARSET: &[u8] = b"0123456789abcdefghijklmnopqrstuvwxyz";
|
||||
pub struct Page {
|
||||
/// Unique saved page ID
|
||||
id: String,
|
||||
/// User ID of page owner
|
||||
user_id: String,
|
||||
/// Visit timestamp
|
||||
timestamp: DateTime<Utc>,
|
||||
/// Page URL
|
||||
@@ -25,22 +33,55 @@ pub struct Page {
|
||||
data: String,
|
||||
}
|
||||
|
||||
/// Save page form
|
||||
#[derive(rocket::FromForm)]
|
||||
pub struct SavePageForm {
|
||||
/// Page URL
|
||||
url: String,
|
||||
/// Page content (SingleFile HTML)
|
||||
data: String,
|
||||
}
|
||||
|
||||
/// Save a visited page in SingleFile format
|
||||
#[rocket::post("/save", data = "<form>")]
|
||||
pub async fn post_page(
|
||||
user: User,
|
||||
form: Form<SavePageForm>,
|
||||
ctx: &State<Context>,
|
||||
config: &State<Config>,
|
||||
) -> Result<Json<Page>, String> {
|
||||
match save_page(&form.url, &form.data, ctx, config, user.id()).await {
|
||||
Ok(p) => Ok(Json(p)),
|
||||
Err(e) => {
|
||||
error!("Failed to save page: {}", e);
|
||||
Err(e.to_string())
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Save the page
|
||||
pub async fn save_page(
|
||||
url: &str,
|
||||
data: &str,
|
||||
ctx: &Context,
|
||||
config: &Config,
|
||||
user: &str,
|
||||
) -> Result<Page, Error> {
|
||||
let client = &ctx.client;
|
||||
let index = client.get_index(&ctx.config.meilisearch.index).await?;
|
||||
let span = span!(Level::INFO, "save_page", url = url, user = user);
|
||||
let _guard = span.enter();
|
||||
let index_name = &config.meilisearch.index;
|
||||
debug!("Saving page in Meilisearch index {}", index_name);
|
||||
let index = ctx.client.get_index(index_name).await?;
|
||||
let doc = Page {
|
||||
id: gen_id(),
|
||||
user_id: user.into(),
|
||||
timestamp: Utc::now(),
|
||||
url: url.into(),
|
||||
title: extract_title(data),
|
||||
data: data.into(),
|
||||
};
|
||||
index.add_or_replace(&[doc.clone()], Some("id")).await?;
|
||||
event!(Level::INFO, "Saved page {}", doc.id);
|
||||
Ok(doc)
|
||||
}
|
||||
|
||||
|
||||
86
tests/integration/auth.rs
Normal file
86
tests/integration/auth.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
use rocket::http::Status;
|
||||
use rocket::local::asynchronous::Client;
|
||||
use rocket::uri;
|
||||
use scraper::{Html, Selector};
|
||||
use tracing::debug;
|
||||
|
||||
use seensite::auth::*;
|
||||
use seensite::Context;
|
||||
|
||||
#[rocket::async_test]
|
||||
async fn test_login() {
|
||||
super::setup();
|
||||
|
||||
let client = Client::tracked(seensite::rocket()).await.unwrap();
|
||||
let dex = reqwest::Client::builder()
|
||||
.redirect(reqwest::redirect::Policy::none())
|
||||
.cookie_store(true)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let ctx: &Context = client.rocket().state().unwrap();
|
||||
|
||||
// First, initiate the login process
|
||||
let req = client.get(uri![oidc_login]);
|
||||
let res = req.dispatch().await;
|
||||
assert_eq!(res.status(), Status::SeeOther);
|
||||
let mut location = res.headers().get_one("Location").unwrap().to_string();
|
||||
|
||||
// Next, follow the redirect URL provided by the login page.
|
||||
// This will redirect again.
|
||||
let res = loop {
|
||||
debug!("Redirect: {}", location);
|
||||
let res = dex.get(location).send().await.unwrap();
|
||||
if res.status() == reqwest::StatusCode::FOUND {
|
||||
let base_url = res.url().clone();
|
||||
location = base_url
|
||||
.join(res.headers().get("Location").unwrap().to_str().unwrap())
|
||||
.unwrap()
|
||||
.to_string();
|
||||
continue;
|
||||
}
|
||||
break res;
|
||||
};
|
||||
|
||||
// After all the redirects, we end up on the IdP login form.
|
||||
assert_eq!(res.status(), reqwest::StatusCode::OK);
|
||||
// Obtain the login form target
|
||||
let base_url = res.url().clone();
|
||||
let body = res.text().await;
|
||||
let doc = Html::parse_fragment(&body.unwrap());
|
||||
let sel = Selector::parse("form").unwrap();
|
||||
let form = doc.select(&sel).next().unwrap();
|
||||
let action = form.attr("action").unwrap();
|
||||
let url = base_url.join(action).unwrap();
|
||||
// Post the user credentials to the IdP login form
|
||||
let res = dex
|
||||
.post(url)
|
||||
.form(&[("login", "user@example.com"), ("password", "password")])
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), reqwest::StatusCode::SEE_OTHER);
|
||||
|
||||
// The result of the IdP login form submission is another redirect
|
||||
// to the OIDC callback of our application.
|
||||
let location = reqwest::Url::parse(
|
||||
res.headers().get("Location").unwrap().to_str().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
let callback =
|
||||
format!("{}?{}", location.path(), location.query().unwrap());
|
||||
debug!("Callback: {}", callback);
|
||||
|
||||
// Finally, make the callback request to finish the login process.
|
||||
let res = client.get(callback).dispatch().await;
|
||||
assert_eq!(res.status(), Status::SeeOther);
|
||||
let location = res.headers().get_one("Location").unwrap().to_string();
|
||||
assert_eq!(location, "/");
|
||||
let cookie = res.cookies().get("auth.token").unwrap();
|
||||
debug!("Cookie: {:?}", cookie);
|
||||
|
||||
// Check to ensure the cookie contains a valid token
|
||||
let user = ctx.decode_jwt(cookie.value()).unwrap();
|
||||
debug!("User: {:?}", user);
|
||||
assert!(!user.id().is_empty());
|
||||
}
|
||||
22
tests/integration/main.rs
Normal file
22
tests/integration/main.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
mod auth;
|
||||
mod page;
|
||||
|
||||
use std::sync::LazyLock;
|
||||
|
||||
static SETUP: LazyLock<()> = LazyLock::new(|| {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(tracing_subscriber::EnvFilter::new(
|
||||
concat!(
|
||||
env!("CARGO_PKG_NAME"),
|
||||
"=trace,",
|
||||
"seensite=trace,",
|
||||
"debug",
|
||||
)
|
||||
))
|
||||
.with_test_writer()
|
||||
.init();
|
||||
});
|
||||
|
||||
fn setup() {
|
||||
LazyLock::force(&SETUP);
|
||||
}
|
||||
65
tests/integration/page.rs
Normal file
65
tests/integration/page.rs
Normal file
@@ -0,0 +1,65 @@
|
||||
use form_urlencoded::Serializer;
|
||||
use rocket::http::Status;
|
||||
use rocket::http::{ContentType, Header};
|
||||
use rocket::local::blocking::Client;
|
||||
use rocket::serde::json::Value;
|
||||
use rocket::uri;
|
||||
|
||||
use seensite::auth::User;
|
||||
use seensite::page::*;
|
||||
use seensite::Context;
|
||||
|
||||
static TEST_URL: &str = r"http://example.org/page1.html";
|
||||
|
||||
static TEST_HTML: &str = r#"<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Example Page</title>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Example Page</title>
|
||||
<p>Lorem ipsum dolor sit amet, consectetur adipiscing elit. Donec malesuada,
|
||||
tellus eu fringilla finibus, turpis sapien faucibus elit, a fringilla dolor
|
||||
urna volutpat dui. Curabitur eget dui aliquet, gravida velit tempor, porta
|
||||
ipsum. Donec finibus orci quis velit tincidunt placerat. Aliquam erat volutpat.
|
||||
Nullam id nisl odio. Praesent egestas fringilla ultricies. Aenean blandit
|
||||
lectus mauris, quis auctor ipsum porttitor quis. Vivamus egestas cursus erat,
|
||||
et egestas diam volutpat eu. Vestibulum imperdiet purus ac turpis sodales, sit
|
||||
amet auctor risus lacinia. Duis feugiat lobortis orci quis sagittis.</p>
|
||||
</html>
|
||||
"#;
|
||||
|
||||
#[test]
|
||||
fn test_post_page() {
|
||||
super::setup();
|
||||
let client = Client::tracked(seensite::rocket()).unwrap();
|
||||
let ctx: &Context = client.rocket().state().unwrap();
|
||||
let user = User::new("test1");
|
||||
let token = ctx.make_jwt(&user, 60).unwrap();
|
||||
let data = Serializer::new(String::new())
|
||||
.append_pair("url", TEST_URL)
|
||||
.append_pair("data", TEST_HTML)
|
||||
.finish();
|
||||
let req = client
|
||||
.post(uri![post_page])
|
||||
.header(ContentType::Form)
|
||||
.header(Header::new("Authorization", format!("Bearer {}", token)))
|
||||
.body(&data);
|
||||
let res = req.dispatch();
|
||||
assert_eq!(res.status(), Status::Ok);
|
||||
let page = res.into_json::<Value>().unwrap();
|
||||
assert_eq!(page.get("title").unwrap().as_str().unwrap(), "Example Page");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_post_page_unauth() {
|
||||
super::setup();
|
||||
let client = Client::tracked(seensite::rocket()).unwrap();
|
||||
let data = Serializer::new(String::new())
|
||||
.append_pair("url", TEST_URL)
|
||||
.append_pair("data", TEST_HTML)
|
||||
.finish();
|
||||
let req = client.post(uri![post_page]).body(&data);
|
||||
let res = req.dispatch();
|
||||
assert_eq!(res.status(), Status::Unauthorized);
|
||||
}
|
||||
Reference in New Issue
Block a user