Compare commits

...

10 Commits

Author SHA1 Message Date
dbd9165626 context: Do not cache OIDC client
It turns out, we do NOT want to keep one single, global OIDC client data
structure.  There are two major problems with this:

1. If the OIDC IdP happens to be unavailable when the process starts,
   Rocket will fail to ignite and the process will exit.  This is
   unnecessary, since the only functionality that will be unavailable
   without the IdP is new logins; existing sessions/tokens will still be
   valid.
2. Identity providers can change keys, URLs, etc. at any time.  If we
   cache everything and never look it up again, all future login
   attempts will fail until the server is restarted.

The official recommendation for caching OIDC IdP configuration and keys
is to use native HTTP cache control.  Unfortunately, most IdPs
explicitly disable caching of their HTTP responses.
2025-04-08 21:40:20 -05:00
a50dca7fae auth: Introduce User struct
The `UserClaims` structure is an implementation detail of how the JWT
encoding process works.  We do not need to expose all of the details of
the JWT, such as issuer, audience, expiration, etc. to rest of the
application.  Route handlers should only be concerned with the
information about the user, rather than the metadata about how the user
was authenticated.
2025-04-08 20:29:47 -05:00
3d2772cfc8 auth: Implement OpenID Connect login flow
This commit adds two path operations, *GET /login* and *GET
/oidc-callback*, which initiate and complete the OpenID connect login
flow, respectively.  Only the *Authorization Code* flow is supported,
since this is the only flow implemented by Authelia.

There is quite a bit of boilerplate required to fully implement an OIDC
relying party, especially in Rust.  The documentation for
`openidconnect` is decent, but it still took quite a bit of trial and
error to get everything working.

After successfully finishing the OIDC login, the client will receive a
cookie containing a JWT that can be used for further communication with
the server.  We're not using the OIDC tokens themselves for
authorization.

For development and testing, Dex is a simple and convenient OIDC IdP.
The only caveat is its configuration file must contain list the TCP port
clients will use to connect to it, meaning we cannot use Podman dynamic
port allocation like we do for Meilisearch.  Ultimately, this just means
the integration tests will fail if there is another process already
listening on 5556.
2025-04-07 19:17:16 -05:00
42502694f3 error: Move errors to a separate module
For some reason, using the `thiserror::Error` derive macro causes the
syntax highlighting to fail for the rest of the code in the file, at
least in Neovim.  Having all the errors in one module will consolidate
this effect to that one file.
2025-04-05 17:08:51 -05:00
e1da535dc2 meilisearch: Trim whitespace from token
When reading the Meilisearch token from the file specified in the
configuration, we need to ensure any whitespace are trimmed from the
string.  If the token file was created with a text editor, or even a
shell pipeline, it's likely to have a trailing newline character.  If we
do not remove this, authenticated requests to Meilisearch will fail.
2025-04-05 17:08:51 -05:00
c227bec62d run-tests: Add harness for integration tests
The `run-tests.sh` script sets up a full environment for the integration
tests.  This includes starting Meilisearch (with a master key to enable
authentication) and generating an ephemeral JWT secret.  After the tests
are run, the environment is cleaned up.

```sh
just test

just unit-tests

just integration-tests
```
2025-04-05 17:08:49 -05:00
76cf57ebe0 Begin integration tests
Refactoring the code a bit here to make the `Rocket` instance available
to the integration tests.  To do this, we have to convert to a library
crate (`lib.rs`) with an executable entry point (`main.rs`).  This
allows the tests, which are separate crates, to import types and
functions from the library.

Besides splitting the `rocket` function into two parts (one in `lib.rs`
that creates the `Rocket<Build>` and another in `main.rs` that becomes
the process entry point), I have reworked the initialization process to
make better use of Rocket's "fairings" feature.  We don't want to call
`process::exit()` in a test, so if there is a problem reading the
configuration or initializing the context, we need to report it to
Rocket instead.
2025-04-05 17:07:39 -05:00
720bb690ea auth: Initial JWT implementation
We'll use a JWT in the `Authorization` request header to identify the
user saving a page.  The token will need to be set in the _authorization
token_ field in the SingleFile configuration so it will be included when
uploading.
2025-04-05 17:07:39 -05:00
a1308507af main: Add better config/init error messages
The default messages printed when the process panics because the
configuration could not be loaded or the application context could not
be initialized are somewhat difficult to read.  Instead of calling
`unwrap` in these cases, we need to explicitly handle the errors and
print more appropriate messages.
2025-04-05 17:07:19 -05:00
df560b18f2 page: Instrument save_page function
Adding a span and event to track indexing time.
2025-04-05 17:07:19 -05:00
17 changed files with 2304 additions and 99 deletions

2
.gitignore vendored
View File

@@ -1,3 +1,5 @@
/target
/Rocket.toml
/jwt.secret
/meilisearch.token
/oidc.secret

1179
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -6,11 +6,19 @@ edition = "2021"
[dependencies]
chrono = { version = "0.4.40", default-features = false, features = ["std", "clock", "serde"] }
html5ever = "0.27.0"
jsonwebtoken = { version = "9.3.1", default-features = false }
markup5ever_rcdom = "0.3.0"
meilisearch-sdk = "0.28.0"
openidconnect = { version = "4.0.0", default-features = false, features = ["reqwest", "native-tls"] }
rand = "0.9.0"
rocket = { version = "0.5.1", features = ["json"] }
reqwest = { version = "0.12.15", features = ["json", "native-tls"] }
rocket = { version = "0.5.1", features = ["json", "secrets"] }
serde = { version = "1.0.219", features = ["derive"] }
thiserror = "2.0.12"
tracing = "0.1.41"
tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
[dev-dependencies]
form_urlencoded = "1.2.1"
reqwest = { version = "0.12.15", features = ["cookies"] }
scraper = { version = "0.23.1", default-features = false }

9
Justfile Normal file
View File

@@ -0,0 +1,9 @@
# vim: set et :
integration-tests:
./run-tests.sh --test=integration
unit-tests:
cargo test --lib
test: unit-tests integration-tests

29
dex.yaml Normal file
View File

@@ -0,0 +1,29 @@
issuer: http://127.0.0.1:5556/dex
storage:
type: memory
web:
http: 0.0.0.0:5556
staticClients:
- id: example-app
redirectURIs:
- 'http://localhost:8000/oidc-callback'
name: 'Example App'
secret: example-app-secret
enablePasswordDB: true
staticPasswords:
- email: "user@example.com"
# bcrypt hash of "password"
hash: "$2y$10$BvrNnoSifaDAcS95zT5CyegMKE90S9gpiMHcJj82hVnyRYwn2LNwS"
username: "testuser"
userID: "1234"
oauth2:
skipApprovalScreen: true
logger:
level: debug

37
examples/make-token.rs Normal file
View File

@@ -0,0 +1,37 @@
use std::io::Read;
use std::time::SystemTime;
use jsonwebtoken::{encode, EncodingKey, Header};
use serde::Serialize;
#[derive(Debug, Serialize)]
struct UserClaims {
aud: String,
exp: u64,
iat: u64,
iss: String,
nbf: u64,
sub: String,
}
fn main() {
let args: Vec<_> = std::env::args().collect();
let mut secret = vec![];
let mut f = std::fs::File::open(&args[1]).unwrap();
f.read_to_end(&mut secret).unwrap();
let k = EncodingKey::from_secret(&secret);
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs();
let claims = UserClaims {
aud: env!("CARGO_PKG_NAME").into(),
exp: now + 604800,
iat: now,
iss: env!("CARGO_PKG_NAME").into(),
nbf: now - 60,
sub: args[2].to_string(),
};
let jwt = encode(&Header::default(), &claims, &k).unwrap();
println!("{}", jwt);
}

93
run-tests.sh Executable file
View File

@@ -0,0 +1,93 @@
#!/bin/sh
# vim: set ts=4 sts=4 sw=4 et :
cleanup() {
if [ -n "${cid}" ]; then
echo 'Stopping Meilisearch ...' >&2
podman stop "${cid}" >/dev/null
fi
unset cid
if [ -f "${meilisearch_key}" ]; then
rm "${meilisearch_key}"
fi
unset meilisearch_key
if [ -n "${dex_cid}" ]; then
echo 'Stopping Dex ...' >&2
podman stop "${dex_cid}" >/dev/null
fi
unset cid
if [ -f "${jwt_key}" ]; then
rm "${jwt_key}"
fi
unset jwt_key
if [ -f "${ROCKET_CONFIG}" ]; then
rm "${ROCKET_CONFIG}"
fi
unset ROCKET_CONFIG
}
trap cleanup INT TERM QUIT EXIT
echo 'Generating Meilisearch master key ...' >&2
MEILI_MASTER_KEY=$(tr -cd 0-9A-Za-z_- < /dev/urandom | head -c 32)
echo 'Starting Meilisearch ...' >&2
cid=$(
export MEILI_MASTER_KEY
podman run --rm -d -P \
-e MEILI_NO_ANALYTICS=true \
-e MEILI_MASTER_KEY \
docker.io/getmeili/meilisearch:v1.13
)
port=$(podman port ${cid} 7700/tcp | cut -d: -f2)
printf 'Waiting for Meilisearch on port %d ...' "${port}" >&2
until curl -fs -o /dev/null 127.0.0.1:${port}; do
printf '.'
sleep 0.25
done
echo ' ready' >&2
meilisearch_key=$(mktemp meilisearch.token.XXXXXX)
curl -s 127.0.0.1:${port}/keys -H "Authorization: Bearer ${MEILI_MASTER_KEY}" \
| jq -r '.results[] | select(.name == "Default Admin API Key") | .key' \
> "${meilisearch_key}"
echo 'Starting Dex ...'
dex_cid=$(
podman run --rm -d \
-p 5556:5556 \
-v ./dex.yaml:/etc/dex/config.yaml:ro,z \
ghcr.io/dexidp/dex \
dex serve /etc/dex/config.yaml
) || exit
dex_port=5556
printf 'Waiting for Dex on port %d ...' "${dex_port}" >&2
until curl -fs -o /dev/null 127.0.0.1:${dex_port}/dex/healthz; do
printf '.'
sleep 0.25
done
echo ' ready' >&2
echo 'Generating JWT secret key ...' >&2
jwt_key=$(mktemp jwt.secret.XXXXXX)
tr -cd 0-9A-Za-z_- < /dev/urandom | head -c 32 > "${jwt_key}"
ROCKET_CONFIG=$(mktemp Rocket.toml.XXXXXX)
cat > "${ROCKET_CONFIG}" <<EOF
[test.meilisearch]
url = "http://127.0.0.1:${port}"
token = "${meilisearch_key}"
[test.auth]
jwt_secret = "${jwt_key}"
[test.oidc]
discovery_url = "http://127.0.0.1:${dex_port}/dex"
client_id = "example-app"
client_secret = "example-app-secret"
callback_url = "http://localhost:8000/oidc-callback"
EOF
export ROCKET_PROFILE=test
export ROCKET_CONFIG
cargo test "$@"

541
src/auth.rs Normal file
View File

@@ -0,0 +1,541 @@
use std::path::Path;
use std::collections::HashMap;
use std::io::Read;
use std::time::SystemTime;
use jsonwebtoken::{
decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation,
};
use openidconnect::core::{
CoreAuthenticationFlow, CoreClient, CoreIdTokenClaims,
CoreProviderMetadata,
};
use openidconnect::{
AccessTokenHash, AuthorizationCode, ClientId, ClientSecret, CsrfToken,
EndpointMaybeSet, EndpointNotSet, EndpointSet, IssuerUrl, Nonce,
OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl,
Scope, TokenResponse,
};
use rocket::config::Config as RocketConfig;
use rocket::http::{Cookie, CookieJar, SameSite, Status};
use rocket::request::{FromRequest, Outcome, Request};
use rocket::response::Redirect;
use rocket::time::{Duration, OffsetDateTime};
use rocket::State;
use serde::{Deserialize, Serialize};
use tracing::{debug, error, trace, warn};
use crate::config::Config;
use crate::error::{LoginError, OidcError};
use crate::Context;
#[derive(Debug, PartialEq, Serialize, Deserialize)]
struct UserClaims {
aud: String,
exp: u64,
iat: u64,
iss: String,
nbf: u64,
sub: String,
}
/// Represents an authenticated user
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct User {
id: String,
}
impl User {
pub fn new(id: &'static str) -> Self {
Self {
id: id.into(),
}
}
/// Return the user ID
///
/// The user ID is an arbitrary string assigned by the identity
/// provider. It does NOT necessarily represent a username, email
/// address, or any other name; it is simply a unique identifier.
pub fn id(&self) -> &str {
&self.id
}
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for User {
type Error = &'static str;
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let secret = &req.rocket().state::<Context>().unwrap().jwt_secret;
match req.cookies().get("auth.token") {
Some(c) => match UserClaims::from_jwt(c.value(), secret) {
Ok(d) => Outcome::Success(d.into()),
Err(e) => {
debug!("Invalid auth token: {}", e);
Outcome::Error((Status::Unauthorized, "Unauthorized"))
},
},
None => match req.headers().get_one("Authorization") {
Some(s) => match extract_token(secret, s) {
Ok(d) => Outcome::Success(d.into()),
Err(e) => Outcome::Error((Status::Unauthorized, e)),
},
None => Outcome::Error((Status::Unauthorized, "Unauthorized")),
},
}
}
}
impl From<UserClaims> for User {
fn from(claims: UserClaims) -> Self {
Self { id: claims.sub }
}
}
impl UserClaims {
fn to_jwt(
&self,
secret: &[u8],
) -> Result<String, jsonwebtoken::errors::Error> {
let k = EncodingKey::from_secret(secret);
encode(&Header::default(), self, &k)
}
fn from_jwt(
token: &str,
secret: &[u8],
) -> Result<Self, jsonwebtoken::errors::Error> {
let mut v = Validation::new(Algorithm::HS256);
v.validate_nbf = true;
v.set_issuer(&[env!("CARGO_PKG_NAME")]);
v.set_audience(&[env!("CARGO_PKG_NAME")]);
let k = DecodingKey::from_secret(secret);
Ok(decode(token, &k, &v)?.claims)
}
}
impl Context {
/// Create a signed JWT for a user
pub fn make_jwt(
&self,
user: &User,
ttl: u64,
) -> Result<String, jsonwebtoken::errors::Error> {
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs();
let claims = UserClaims {
aud: env!("CARGO_PKG_NAME").into(),
exp: now + ttl,
iat: now,
iss: env!("CARGO_PKG_NAME").into(),
nbf: now - 60,
sub: user.id().into(),
};
claims.to_jwt(&self.jwt_secret)
}
/// Decode the [`User`]
pub fn decode_jwt(
&self,
token: &str,
) -> Result<User, jsonwebtoken::errors::Error> {
Ok(UserClaims::from_jwt(token, &self.jwt_secret)?.into())
}
}
pub type OidcClient = CoreClient<
EndpointSet,
EndpointNotSet,
EndpointNotSet,
EndpointNotSet,
EndpointMaybeSet,
EndpointMaybeSet,
>;
/// Represents the OIDC login state, which is stored in a cookie
#[derive(Debug, Deserialize, Serialize)]
struct OidcClientState {
pkce: String,
csrf: String,
nonce: String,
}
/// Error response returned to the client when OIDC login fails
#[derive(Debug, rocket::Responder)]
#[response(status = 401)]
pub struct LoginFailed(String);
impl From<LoginError> for LoginFailed {
fn from(e: LoginError) -> Self {
error!("Login failed: {}", e);
Self(e.to_string())
}
}
/// Create an openidconnect Client instance from the given Config
pub async fn get_oidc_client(
discovery_url: String,
client_id: String,
client_secret: Option<String>,
callback_url: String,
http_client: &reqwest::Client,
) -> Result<OidcClient, OidcError> {
let provider_metadata = CoreProviderMetadata::discover_async(
IssuerUrl::new(discovery_url)?,
http_client,
)
.await?;
Ok(CoreClient::from_provider_metadata(
provider_metadata,
ClientId::new(client_id),
client_secret.map(ClientSecret::new),
)
.set_redirect_uri(RedirectUrl::new(callback_url)?))
}
/// Load a secret from the specified path
pub fn load_secret<P: AsRef<Path>>(path: P) -> Result<Vec<u8>, std::io::Error> {
let mut secret = vec![];
let mut f = std::fs::File::open(path)?;
f.read_to_end(&mut secret)?;
Ok(secret)
}
/// Extract a JWT from the Authorization HTTP request header
fn extract_token(
secret: &[u8],
header: &str,
) -> Result<UserClaims, &'static str> {
let mut iter = header.split_ascii_whitespace();
let scheme = iter.next();
let token = iter.next();
if token.is_some() && Some("Bearer") != scheme {
return Err("Unsupported authorization scheme");
}
if let Some(token) = token {
Ok(UserClaims::from_jwt(token, secret).map_err(|e| {
error!("Failed to decode auth token: {}", e);
"Invalid token"
})?)
} else {
Err("Invalid Authorization header")
}
}
/// Get the authorization code from OIDC callback parameters
///
/// If the callback indicates an authorization failure, or the OAuth2
/// state is invalid, a [`LoginError`] is returned.
fn get_auth_code(
csrf_token: &str,
params: &HashMap<String, String>,
) -> Result<String, LoginError> {
let state = params.get("state").ok_or(LoginError::MissingStateParam)?;
if state != csrf_token {
return Err(LoginError::InvalidCsrfState);
}
if let Some(error) = params.get("error") {
let msg = if let Some(err_desc) = params.get("error_description") {
format!("Error handling OIDC callback ({}): {}", error, err_desc)
} else {
format!("Error handling OIDC callback: {}", error)
};
error!("{}", msg);
return Err(LoginError::IdpError(msg));
}
Ok(params
.get("code")
.ok_or(LoginError::MissingAuthCode)?
.to_string())
}
/// Exchange an OIDC Authorization code for an Identity Token
async fn exchange_code(
code: String,
state: OidcClientState,
ctx: &Context,
) -> Result<CoreIdTokenClaims, LoginError> {
let pkce_verifier = PkceCodeVerifier::new(state.pkce);
debug!("Exchanging authorization code for access token");
let http_client = &ctx.oidc_http_client;
let oidc = ctx.oidc().await?;
let token_response = oidc
.exchange_code(AuthorizationCode::new(code))?
.set_pkce_verifier(pkce_verifier)
.request_async(http_client)
.await
.map_err(LoginError::TokenRequestError)?;
debug!(
"Received response token type {:?}",
token_response.token_type()
);
debug!("Access token: {}", token_response.access_token().secret());
trace!("Token response: {:?}", token_response);
let id_token = token_response
.id_token()
.ok_or(LoginError::MissingIdToken)?;
let id_token_verifier = oidc.id_token_verifier();
let claims = id_token
.claims(&oidc.id_token_verifier(), &Nonce::new(state.nonce))?;
if let Some(expected_access_token_hash) = claims.access_token_hash() {
let actual_access_token_hash = AccessTokenHash::from_token(
token_response.access_token(),
id_token.signing_alg()?,
id_token.signing_key(&id_token_verifier)?,
)?;
if actual_access_token_hash != *expected_access_token_hash {
return Err(LoginError::InvalidAccessToken);
}
}
Ok(claims.clone())
}
/// Initiate OpenID Connect login flow
#[rocket::get("/login")]
pub async fn oidc_login(
ctx: &State<Context>,
cookies: &CookieJar<'_>,
config: &RocketConfig,
) -> Result<Redirect, LoginFailed> {
let (pkce_challenge, pkce_verifier) =
PkceCodeChallenge::new_random_sha256();
trace!("PKCE: {:?} {:?}", pkce_challenge, pkce_verifier);
let oidc = ctx.oidc().await.map_err(LoginError::from)?;
let (auth_url, csrf_token, nonce) = oidc
.authorize_url(
CoreAuthenticationFlow::AuthorizationCode,
CsrfToken::new_random,
Nonce::new_random,
)
.set_pkce_challenge(pkce_challenge)
.add_scope(Scope::new("openid".into()))
.add_scope(Scope::new("profile".into()))
.add_scope(Scope::new("email".into()))
.add_scope(Scope::new("groups".into()))
.url();
trace!(
"CSRF token: {}, nonce: {}",
csrf_token.secret(),
nonce.secret()
);
let state = OidcClientState {
pkce: pkce_verifier.into_secret(),
csrf: csrf_token.into_secret(),
nonce: nonce.secret().to_string(),
};
match rocket::serde::json::to_string(&state) {
Ok(s) => {
cookies.add_private(
Cookie::build(("oidc.state", s))
.secure(config.profile != RocketConfig::DEBUG_PROFILE)
.same_site(
if config.profile == RocketConfig::DEBUG_PROFILE {
SameSite::Lax
} else {
SameSite::Strict
},
)
.http_only(true)
.expires(None)
.build(),
);
},
Err(e) => {
error!("Failed to serialize OIDC client state: {}", e);
},
};
Ok(Redirect::to(auth_url.to_string()))
}
/// Handle OpenID Connect authorization callback
#[rocket::get("/oidc-callback?<params..>")]
pub async fn oidc_callback(
params: HashMap<String, String>,
cookies: &CookieJar<'_>,
ctx: &State<Context>,
config: &State<Config>,
) -> Result<Redirect, LoginFailed> {
trace!("{:?}", params);
let state = cookies
.get_private("oidc.state")
.inspect(|c| cookies.remove(c.clone()))
.ok_or(LoginError::MissingClientState)?;
let state: OidcClientState = rocket::serde::json::from_str(state.value())
.map_err(|e| {
error!("Failed to parse OIDC client state: {}", e);
LoginError::InvalidClientState
})?;
let code = get_auth_code(&state.csrf, &params)?;
trace!("Got authorization code: {}", code);
let claims = exchange_code(code, state, ctx).await?;
let user = User {
id: claims.subject().to_string(),
};
let expires = OffsetDateTime::now_utc()
+ Duration::seconds(config.auth.login_ttl.try_into().unwrap_or_else(
|e| {
warn!("Invalid login TTL value: {}", e);
i64::MAX
},
));
cookies.add(
Cookie::build((
"auth.token",
ctx.make_jwt(&user, config.auth.login_ttl)
.map_err(LoginError::from)?,
))
.secure(true)
.http_only(true)
.expires(expires)
.build(),
);
Ok(Redirect::to("/"))
}
#[cfg(test)]
mod test {
use super::*;
static SECRET: &[u8; 32] = b"6gmrLQ4PGjeUpT3Xs48thx9Cu6XE5pgD";
fn now() -> u64 {
SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs()
}
fn make_token(claims: &UserClaims) -> String {
claims.to_jwt(SECRET).unwrap()
}
#[test]
fn test_extract_token() {
let now = now();
let claims = UserClaims {
aud: env!("CARGO_PKG_NAME").into(),
exp: now + 60,
iat: now,
iss: env!("CARGO_PKG_NAME").into(),
nbf: now - 60,
sub: "test1".into(),
};
let jwt = make_token(&claims);
let header = format!("Bearer {}", jwt);
let claims = extract_token(SECRET, &header).unwrap();
assert_eq!(claims, claims);
}
#[test]
fn test_extract_token_expired() {
let now = now() - 600;
let claims = UserClaims {
aud: env!("CARGO_PKG_NAME").into(),
exp: now + 60,
iat: now,
iss: env!("CARGO_PKG_NAME").into(),
nbf: now - 60,
sub: "test1".into(),
};
let jwt = make_token(&claims);
let header = format!("Bearer {}", jwt);
let err = extract_token(SECRET, &header).unwrap_err();
assert_eq!(err, "Invalid token");
}
#[test]
fn test_extract_token_future() {
let now = now() + 600;
let claims = UserClaims {
aud: env!("CARGO_PKG_NAME").into(),
exp: now + 60,
iat: now,
iss: env!("CARGO_PKG_NAME").into(),
nbf: now - 60,
sub: "test1".into(),
};
let jwt = make_token(&claims);
let header = format!("Bearer {}", jwt);
let err = extract_token(SECRET, &header).unwrap_err();
assert_eq!(err, "Invalid token");
}
#[test]
fn test_extract_token_bad_issuer() {
let now = now();
let claims = UserClaims {
aud: env!("CARGO_PKG_NAME").into(),
exp: now + 60,
iat: now,
iss: "mallory".into(),
nbf: now - 60,
sub: "test1".into(),
};
let jwt = make_token(&claims);
let header = format!("Bearer {}", jwt);
let err = extract_token(SECRET, &header).unwrap_err();
assert_eq!(err, "Invalid token");
}
#[test]
fn test_extract_token_bad_aud() {
let now = now();
let claims = UserClaims {
aud: "mallory".into(),
exp: now + 60,
iat: now,
iss: env!("CARGO_PKG_NAME").into(),
nbf: now - 60,
sub: "test1".into(),
};
let jwt = make_token(&claims);
let header = format!("Bearer {}", jwt);
let err = extract_token(SECRET, &header).unwrap_err();
assert_eq!(err, "Invalid token");
}
#[test]
fn test_malformed_header() {
let now = now();
let claims = UserClaims {
aud: "mallory".into(),
exp: now + 60,
iat: now,
iss: env!("CARGO_PKG_NAME").into(),
nbf: now - 60,
sub: "test1".into(),
};
let jwt = make_token(&claims);
let err = extract_token(SECRET, &jwt).unwrap_err();
assert_eq!(err, "Invalid Authorization header");
}
#[test]
fn test_unsupported_auth_scheme() {
let header = "Basic dXNlcm5hbWU6cGFzc3dvcmQ=";
let err = extract_token(SECRET, header).unwrap_err();
assert_eq!(err, "Unsupported authorization scheme");
}
#[test]
fn test_userclaims_jwt_round_trip() {
let now = now();
let claims = UserClaims {
aud: env!("CARGO_PKG_NAME").into(),
exp: now + 60,
iat: now,
iss: env!("CARGO_PKG_NAME").into(),
nbf: now - 60,
sub: "test1".into(),
};
let jwt = claims.to_jwt(SECRET).unwrap();
let claims1 = UserClaims::from_jwt(&jwt, SECRET).unwrap();
assert_eq!(claims, claims1);
}
}

View File

@@ -1,3 +1,5 @@
use std::path::PathBuf;
use serde::Deserialize;
#[derive(Debug, Deserialize)]
@@ -9,11 +11,32 @@ pub struct MeilisearchConfig {
pub index: String,
}
#[derive(Debug, Deserialize)]
pub struct AuthConfig {
pub jwt_secret: PathBuf,
#[serde(default = "default_login_ttl")]
pub login_ttl: u64,
}
#[derive(Clone, Debug, Deserialize)]
pub struct OidcConfig {
pub discovery_url: String,
pub client_id: String,
pub client_secret: Option<String>,
pub callback_url: String,
}
#[derive(Debug, Deserialize)]
pub struct Config {
pub auth: AuthConfig,
pub oidc: OidcConfig,
pub meilisearch: MeilisearchConfig,
}
fn default_index() -> String {
env!("CARGO_PKG_NAME").into()
}
fn default_login_ttl() -> u64 {
604800
}

68
src/error.rs Normal file
View File

@@ -0,0 +1,68 @@
use openidconnect::core::CoreErrorResponseType;
use openidconnect::url::ParseError;
use openidconnect::{
ClaimsVerificationError, ConfigurationError, DiscoveryError,
HttpClientError, RequestTokenError, SignatureVerificationError,
SigningError, StandardErrorResponse,
};
#[derive(Debug, thiserror::Error)]
pub enum OidcError {
#[error("Invalid URL {0}")]
Url(#[from] ParseError),
#[error("OIDC Discovery failed: {0}")]
Discovery(#[from] DiscoveryError<HttpClientError<reqwest::Error>>),
}
#[derive(Debug, thiserror::Error)]
pub enum LoginError {
#[error("Missing state parameter")]
MissingStateParam,
#[error("Invalid state parameter")]
InvalidCsrfState,
#[error("Missing OIDC client state cookie")]
MissingClientState,
#[error("Invalid OIDC client state cookie")]
InvalidClientState,
#[error("{0}")]
IdpError(String),
#[error("Missing OAuth2 authorization code")]
MissingAuthCode,
#[error("OIDC configuration error: {0}")]
ConfigurationError(#[from] ConfigurationError),
#[error("Token request error: {0}")]
TokenRequestError(
RequestTokenError<
HttpClientError<reqwest::Error>,
StandardErrorResponse<CoreErrorResponseType>,
>,
),
#[error("Server did not return an ID token")]
MissingIdToken,
#[error("Server returned an invalid ID token")]
InvalidIdToken,
#[error("Invalid token claims: {0}")]
ClaimsVerificationError(#[from] ClaimsVerificationError),
#[error("Signature verification error: {0}")]
SignatureVerificationError(#[from] SignatureVerificationError),
#[error("Token signature error: {0}")]
SigningError(#[from] SigningError),
#[error("Invalid access token")]
InvalidAccessToken,
#[error("JWT serialization error: {0}")]
JwtError(#[from] jsonwebtoken::errors::Error),
#[error("Invalid OIDC configuration: {0}")]
Oidc(#[from] OidcError),
}
#[derive(Debug, thiserror::Error)]
pub enum InitError {
#[error("Meilisearch error: {0}")]
Meilisearch(#[from] crate::meilisearch::Error),
#[error("Failed to load JWT secret: {0}")]
LoadJwtSecret(std::io::Error),
#[error("Failed to load OIDC client secret: {0}")]
LoadOidcSecret(std::io::Error),
#[error("Failed to initialize HTTP client: {0}")]
ReqwestError(#[from] reqwest::Error)
}

109
src/lib.rs Normal file
View File

@@ -0,0 +1,109 @@
pub mod auth;
pub mod config;
mod error;
pub mod meilisearch;
pub mod page;
use meilisearch_sdk::client::Client as MeilisearchClient;
use rocket::fairing::{self, AdHoc};
use rocket::Rocket;
use tracing::error;
use config::Config;
pub use error::InitError;
pub struct Context {
client: MeilisearchClient,
jwt_secret: Vec<u8>,
oidc: config::OidcConfig,
oidc_client_secret: Option<String>,
oidc_http_client: reqwest::Client,
}
impl Context {
pub async fn init(config: &Config) -> Result<Self, InitError> {
let client = MeilisearchClient::try_from(config)?;
let jwt_secret = auth::load_secret(&config.auth.jwt_secret)
.map_err(InitError::LoadJwtSecret)?;
let oidc = config.oidc.clone();
let oidc_client_secret = match &config.oidc.client_secret {
Some(p) => match auth::load_secret(p) {
Ok(s) => Some(String::from_utf8_lossy(&s).trim().to_string()),
Err(e) => return Err(InitError::LoadOidcSecret(e)),
},
None => None,
};
let oidc_http_client = reqwest::ClientBuilder::new()
.redirect(reqwest::redirect::Policy::none())
.build()?;
Ok(Self {
client,
jwt_secret,
oidc,
oidc_client_secret,
oidc_http_client,
})
}
pub async fn oidc(&self) -> Result<auth::OidcClient, error::OidcError> {
auth::get_oidc_client(
self.oidc.discovery_url.clone(),
self.oidc.client_id.clone(),
self.oidc_client_secret.clone(),
self.oidc.callback_url.clone(),
&self.oidc_http_client,
)
.await
}
}
/// Initialize the application context
async fn init_context(rocket: Rocket<rocket::Build>) -> fairing::Result {
let config: &Config = rocket.state().unwrap();
let ctx = match Context::init(config).await {
Ok(c) => c,
Err(e) => {
eprintln!("Could not initialize application context: {}", e);
return Err(rocket);
},
};
// Use the JWT secret as the Rocket secret key, for encrypted cookies
let figment = rocket
.figment()
.clone()
.join(("secret_key", &ctx.jwt_secret));
Ok(rocket.configure(figment).manage(ctx))
}
/// Set up Meilisearch
async fn meilisearch_setup(rocket: Rocket<rocket::Build>) -> fairing::Result {
let config: &Config = rocket.state().unwrap();
let ctx: &Context = match rocket.state() {
Some(c) => c,
None => return Err(rocket),
};
let client = &ctx.client;
if let Err(e) =
meilisearch::ensure_index(client, &config.meilisearch.index).await
{
error!("Failed to create Meilisearch index: {}", e);
Err(rocket)
} else {
Ok(rocket)
}
}
pub fn rocket() -> Rocket<rocket::Build> {
rocket::build()
.mount(
"/",
rocket::routes![
auth::oidc_callback,
auth::oidc_login,
page::post_page
],
)
.attach(AdHoc::config::<Config>())
.attach(AdHoc::try_on_ignite("Initialize context", init_context))
.attach(AdHoc::try_on_ignite("Meilisearch Setup", meilisearch_setup))
}

View File

@@ -1,88 +1,9 @@
mod config;
mod meilisearch;
mod page;
use meilisearch_sdk::client::Client as MeilisearchClient;
use rocket::fairing::{self, AdHoc};
use rocket::form::Form;
use rocket::serde::json::Json;
use rocket::{Rocket, State};
use tracing::error;
use config::Config;
struct Context {
#[allow(dead_code)]
config: Config,
client: MeilisearchClient,
}
#[derive(Debug, thiserror::Error)]
enum InitError {
#[error("Meilisearch error: {0}")]
Meilisearch(#[from] meilisearch::Error),
}
impl Context {
pub fn init(config: Config) -> Result<Self, InitError> {
let client = MeilisearchClient::try_from(&config)?;
Ok(Self { config, client })
}
}
/// Save page form
#[derive(rocket::FromForm)]
pub struct SavePageForm {
/// Page URL
url: String,
/// Page content (SingleFile HTML)
data: String,
}
/// Save a visited page in SingleFile format
#[rocket::post("/save", data = "<form>")]
async fn save_page(
form: Form<SavePageForm>,
ctx: &State<Context>,
) -> Result<Json<page::Page>, String> {
match page::save_page(&form.url, &form.data, ctx).await {
Ok(p) => Ok(Json(p)),
Err(e) => {
error!("Failed to save page: {}", e);
Err(e.to_string())
},
}
}
/// Set up Meilisearch
async fn meilisearch_setup(rocket: Rocket<rocket::Build>) -> fairing::Result {
let ctx: &Context = rocket.state().unwrap();
let client = &ctx.client;
let config = &ctx.config;
if let Err(e) =
meilisearch::ensure_index(client, &config.meilisearch.index).await
{
error!("Failed to create Meilisearch index: {}", e);
Err(rocket)
} else {
Ok(rocket)
}
}
#[rocket::launch]
async fn rocket() -> _ {
fn rocket() -> _ {
tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.with_writer(std::io::stderr)
.init();
let rocket = rocket::build();
let config: Config = rocket.figment().extract().unwrap();
let ctx = Context::init(config).unwrap();
rocket
.manage(ctx)
.mount("/", rocket::routes![save_page])
.attach(AdHoc::try_on_ignite("Meilisearch Setup", meilisearch_setup))
seensite::rocket()
}

View File

@@ -20,7 +20,7 @@ impl TryFrom<&Config> for Client {
Some(t) => Some(std::fs::read_to_string(t).map_err(Error::Token)?),
None => None,
};
Ok(Client::new(&config.meilisearch.url, token)?)
Ok(Client::new(&config.meilisearch.url, token.as_deref().map(str::trim))?)
}
}

View File

@@ -4,8 +4,14 @@ use html5ever::tendril::TendrilSink;
use markup5ever_rcdom::{Handle, NodeData, RcDom};
use meilisearch_sdk::errors::Error;
use rand::Rng;
use rocket::form::Form;
use rocket::serde::json::Json;
use rocket::State;
use serde::Serialize;
use tracing::{debug, error, event, span, Level};
use crate::auth::User;
use crate::config::Config;
use crate::Context;
static ID_CHARSET: &[u8] = b"0123456789abcdefghijklmnopqrstuvwxyz";
@@ -15,6 +21,8 @@ static ID_CHARSET: &[u8] = b"0123456789abcdefghijklmnopqrstuvwxyz";
pub struct Page {
/// Unique saved page ID
id: String,
/// User ID of page owner
user_id: String,
/// Visit timestamp
timestamp: DateTime<Utc>,
/// Page URL
@@ -25,22 +33,55 @@ pub struct Page {
data: String,
}
/// Save page form
#[derive(rocket::FromForm)]
pub struct SavePageForm {
/// Page URL
url: String,
/// Page content (SingleFile HTML)
data: String,
}
/// Save a visited page in SingleFile format
#[rocket::post("/save", data = "<form>")]
pub async fn post_page(
user: User,
form: Form<SavePageForm>,
ctx: &State<Context>,
config: &State<Config>,
) -> Result<Json<Page>, String> {
match save_page(&form.url, &form.data, ctx, config, user.id()).await {
Ok(p) => Ok(Json(p)),
Err(e) => {
error!("Failed to save page: {}", e);
Err(e.to_string())
},
}
}
/// Save the page
pub async fn save_page(
url: &str,
data: &str,
ctx: &Context,
config: &Config,
user: &str,
) -> Result<Page, Error> {
let client = &ctx.client;
let index = client.get_index(&ctx.config.meilisearch.index).await?;
let span = span!(Level::INFO, "save_page", url = url, user = user);
let _guard = span.enter();
let index_name = &config.meilisearch.index;
debug!("Saving page in Meilisearch index {}", index_name);
let index = ctx.client.get_index(index_name).await?;
let doc = Page {
id: gen_id(),
user_id: user.into(),
timestamp: Utc::now(),
url: url.into(),
title: extract_title(data),
data: data.into(),
};
index.add_or_replace(&[doc.clone()], Some("id")).await?;
event!(Level::INFO, "Saved page {}", doc.id);
Ok(doc)
}

86
tests/integration/auth.rs Normal file
View File

@@ -0,0 +1,86 @@
use rocket::http::Status;
use rocket::local::asynchronous::Client;
use rocket::uri;
use scraper::{Html, Selector};
use tracing::debug;
use seensite::auth::*;
use seensite::Context;
#[rocket::async_test]
async fn test_login() {
super::setup();
let client = Client::tracked(seensite::rocket()).await.unwrap();
let dex = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.cookie_store(true)
.build()
.unwrap();
let ctx: &Context = client.rocket().state().unwrap();
// First, initiate the login process
let req = client.get(uri![oidc_login]);
let res = req.dispatch().await;
assert_eq!(res.status(), Status::SeeOther);
let mut location = res.headers().get_one("Location").unwrap().to_string();
// Next, follow the redirect URL provided by the login page.
// This will redirect again.
let res = loop {
debug!("Redirect: {}", location);
let res = dex.get(location).send().await.unwrap();
if res.status() == reqwest::StatusCode::FOUND {
let base_url = res.url().clone();
location = base_url
.join(res.headers().get("Location").unwrap().to_str().unwrap())
.unwrap()
.to_string();
continue;
}
break res;
};
// After all the redirects, we end up on the IdP login form.
assert_eq!(res.status(), reqwest::StatusCode::OK);
// Obtain the login form target
let base_url = res.url().clone();
let body = res.text().await;
let doc = Html::parse_fragment(&body.unwrap());
let sel = Selector::parse("form").unwrap();
let form = doc.select(&sel).next().unwrap();
let action = form.attr("action").unwrap();
let url = base_url.join(action).unwrap();
// Post the user credentials to the IdP login form
let res = dex
.post(url)
.form(&[("login", "user@example.com"), ("password", "password")])
.send()
.await
.unwrap();
assert_eq!(res.status(), reqwest::StatusCode::SEE_OTHER);
// The result of the IdP login form submission is another redirect
// to the OIDC callback of our application.
let location = reqwest::Url::parse(
res.headers().get("Location").unwrap().to_str().unwrap(),
)
.unwrap();
let callback =
format!("{}?{}", location.path(), location.query().unwrap());
debug!("Callback: {}", callback);
// Finally, make the callback request to finish the login process.
let res = client.get(callback).dispatch().await;
assert_eq!(res.status(), Status::SeeOther);
let location = res.headers().get_one("Location").unwrap().to_string();
assert_eq!(location, "/");
let cookie = res.cookies().get("auth.token").unwrap();
debug!("Cookie: {:?}", cookie);
// Check to ensure the cookie contains a valid token
let user = ctx.decode_jwt(cookie.value()).unwrap();
debug!("User: {:?}", user);
assert!(!user.id().is_empty());
}

22
tests/integration/main.rs Normal file
View File

@@ -0,0 +1,22 @@
mod auth;
mod page;
use std::sync::LazyLock;
static SETUP: LazyLock<()> = LazyLock::new(|| {
tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::new(
concat!(
env!("CARGO_PKG_NAME"),
"=trace,",
"seensite=trace,",
"debug",
)
))
.with_test_writer()
.init();
});
fn setup() {
LazyLock::force(&SETUP);
}

65
tests/integration/page.rs Normal file
View File

@@ -0,0 +1,65 @@
use form_urlencoded::Serializer;
use rocket::http::Status;
use rocket::http::{ContentType, Header};
use rocket::local::blocking::Client;
use rocket::serde::json::Value;
use rocket::uri;
use seensite::auth::User;
use seensite::page::*;
use seensite::Context;
static TEST_URL: &str = r"http://example.org/page1.html";
static TEST_HTML: &str = r#"<!DOCTYPE html>
<html>
<head>
<title>Example Page</title>
</head>
<body>
<h1>Example Page</title>
<p>Lorem ipsum dolor sit amet, consectetur adipiscing elit. Donec malesuada,
tellus eu fringilla finibus, turpis sapien faucibus elit, a fringilla dolor
urna volutpat dui. Curabitur eget dui aliquet, gravida velit tempor, porta
ipsum. Donec finibus orci quis velit tincidunt placerat. Aliquam erat volutpat.
Nullam id nisl odio. Praesent egestas fringilla ultricies. Aenean blandit
lectus mauris, quis auctor ipsum porttitor quis. Vivamus egestas cursus erat,
et egestas diam volutpat eu. Vestibulum imperdiet purus ac turpis sodales, sit
amet auctor risus lacinia. Duis feugiat lobortis orci quis sagittis.</p>
</html>
"#;
#[test]
fn test_post_page() {
super::setup();
let client = Client::tracked(seensite::rocket()).unwrap();
let ctx: &Context = client.rocket().state().unwrap();
let user = User::new("test1");
let token = ctx.make_jwt(&user, 60).unwrap();
let data = Serializer::new(String::new())
.append_pair("url", TEST_URL)
.append_pair("data", TEST_HTML)
.finish();
let req = client
.post(uri![post_page])
.header(ContentType::Form)
.header(Header::new("Authorization", format!("Bearer {}", token)))
.body(&data);
let res = req.dispatch();
assert_eq!(res.status(), Status::Ok);
let page = res.into_json::<Value>().unwrap();
assert_eq!(page.get("title").unwrap().as_str().unwrap(), "Example Page");
}
#[test]
fn test_post_page_unauth() {
super::setup();
let client = Client::tracked(seensite::rocket()).unwrap();
let data = Serializer::new(String::new())
.append_pair("url", TEST_URL)
.append_pair("data", TEST_HTML)
.finish();
let req = client.post(uri![post_page]).body(&data);
let res = req.dispatch();
assert_eq!(res.status(), Status::Unauthorized);
}