Add support for trusted redirects.

main
Mauro D 2022-08-04 12:20:40 +00:00
parent 45f0aa3d81
commit 4674fb3a39
5 changed files with 74 additions and 2 deletions

View File

@ -15,7 +15,8 @@ readme = "README.md"
serde = { version = "1.0", features = ["derive"]} serde = { version = "1.0", features = ["derive"]}
serde_json = "1.0" serde_json = "1.0"
chrono = { version = "0.4", features = ["serde"]} chrono = { version = "0.4", features = ["serde"]}
reqwest = { version = "0.11", default-features = false, features = ["stream", "rustls-tls"]} #reqwest = { version = "0.11", default-features = false, features = ["stream", "rustls-tls"]}
reqwest = { git = "https://github.com/stalwartlabs/reqwest.git", default-features = false, features = ["stream", "rustls-tls"]}
futures-util = "0.3" futures-util = "0.3"
async-stream = "0.3" async-stream = "0.3"
base64 = "0.13" base64 = "0.13"

View File

@ -39,6 +39,7 @@ impl Client {
Client::handle_error( Client::handle_error(
reqwest::Client::builder() reqwest::Client::builder()
.timeout(Duration::from_millis(self.timeout())) .timeout(Duration::from_millis(self.timeout()))
.redirect(self.redirect_policy())
.default_headers(headers) .default_headers(headers)
.build()? .build()?
.get(download_url) .get(download_url)

View File

@ -48,6 +48,7 @@ impl Client {
&Client::handle_error( &Client::handle_error(
reqwest::Client::builder() reqwest::Client::builder()
.timeout(Duration::from_millis(self.timeout())) .timeout(Duration::from_millis(self.timeout()))
.redirect(self.redirect_policy())
.default_headers(self.headers().clone()) .default_headers(self.headers().clone())
.build()? .build()?
.post(upload_url) .post(upload_url)

View File

@ -6,9 +6,10 @@ use std::{
time::Duration, time::Duration,
}; };
use ahash::AHashSet;
use reqwest::{ use reqwest::{
header::{self}, header::{self},
Response, redirect, Response,
}; };
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
@ -36,6 +37,7 @@ pub struct Client {
session_url: String, session_url: String,
api_url: String, api_url: String,
session_updated: AtomicBool, session_updated: AtomicBool,
trusted_hosts: Arc<AHashSet<String>>,
upload_url: Vec<URLPart<blob::URLParameter>>, upload_url: Vec<URLPart<blob::URLParameter>>,
download_url: Vec<URLPart<blob::URLParameter>>, download_url: Vec<URLPart<blob::URLParameter>>,
@ -53,6 +55,22 @@ pub struct Client {
impl Client { impl Client {
pub async fn connect(url: &str, credentials: impl Into<Credentials>) -> crate::Result<Self> { pub async fn connect(url: &str, credentials: impl Into<Credentials>) -> crate::Result<Self> {
Self::connect_(url, credentials, None::<Vec<String>>).await
}
pub async fn connect_with_trusted(
url: &str,
credentials: impl Into<Credentials>,
trusted_hosts: impl IntoIterator<Item = impl Into<String>>,
) -> crate::Result<Self> {
Self::connect_(url, credentials, trusted_hosts.into()).await
}
async fn connect_(
url: &str,
credentials: impl Into<Credentials>,
trusted_hosts: Option<impl IntoIterator<Item = impl Into<String>>>,
) -> crate::Result<Self> {
let authorization = match credentials.into() { let authorization = match credentials.into() {
Credentials::Basic(s) => format!("Basic {}", s), Credentials::Basic(s) => format!("Basic {}", s),
Credentials::Bearer(s) => format!("Bearer {}", s), Credentials::Bearer(s) => format!("Bearer {}", s),
@ -67,10 +85,31 @@ impl Client {
header::HeaderValue::from_str(&authorization).unwrap(), header::HeaderValue::from_str(&authorization).unwrap(),
); );
let trusted_hosts = Arc::new(
trusted_hosts
.map(|hosts| hosts.into_iter().map(|h| h.into()).collect::<AHashSet<_>>())
.unwrap_or_default(),
);
let trusted_hosts_ = trusted_hosts.clone();
let session: Session = serde_json::from_slice( let session: Session = serde_json::from_slice(
&Client::handle_error( &Client::handle_error(
reqwest::Client::builder() reqwest::Client::builder()
.timeout(Duration::from_millis(DEFAULT_TIMEOUT_MS)) .timeout(Duration::from_millis(DEFAULT_TIMEOUT_MS))
.redirect(redirect::Policy::custom(move |attempt| {
if attempt.previous().len() > 5 {
attempt.error("Too many redirects.")
} else if matches!( attempt.url().host_str(), Some(host) if trusted_hosts_.contains(host) )
{
attempt.follow_trusted()
} else {
let message = format!(
"Aborting redirect request to unknown host '{}'.",
attempt.url().host_str().unwrap_or("")
);
attempt.error(message)
}
}))
.default_headers(headers.clone()) .default_headers(headers.clone())
.build()? .build()?
.get(url) .get(url)
@ -101,6 +140,7 @@ impl Client {
session: parking_lot::Mutex::new(Arc::new(session)), session: parking_lot::Mutex::new(Arc::new(session)),
session_url: url.to_string(), session_url: url.to_string(),
session_updated: true.into(), session_updated: true.into(),
trusted_hosts,
#[cfg(feature = "websockets")] #[cfg(feature = "websockets")]
authorization, authorization,
timeout: DEFAULT_TIMEOUT_MS, timeout: DEFAULT_TIMEOUT_MS,
@ -116,6 +156,14 @@ impl Client {
self self
} }
pub fn set_trusted_hosts(
&mut self,
trusted_hosts: impl IntoIterator<Item = impl Into<String>>,
) -> &mut Self {
self.trusted_hosts = Arc::new(trusted_hosts.into_iter().map(|h| h.into()).collect());
self
}
pub fn timeout(&self) -> u64 { pub fn timeout(&self) -> u64 {
self.timeout self.timeout
} }
@ -132,6 +180,24 @@ impl Client {
&self.headers &self.headers
} }
pub(crate) fn redirect_policy(&self) -> redirect::Policy {
let trusted_hosts = self.trusted_hosts.clone();
redirect::Policy::custom(move |attempt| {
if attempt.previous().len() > 5 {
attempt.error("Too many redirects.")
} else if matches!( attempt.url().host_str(), Some(host) if trusted_hosts.contains(host) )
{
attempt.follow_trusted()
} else {
let message = format!(
"Aborting redirect request to unknown host '{}'.",
attempt.url().host_str().unwrap_or("")
);
attempt.error(message)
}
})
}
pub async fn send<R>( pub async fn send<R>(
&self, &self,
request: &request::Request<'_>, request: &request::Request<'_>,
@ -142,6 +208,7 @@ impl Client {
let response: response::Response<R> = serde_json::from_slice( let response: response::Response<R> = serde_json::from_slice(
&Client::handle_error( &Client::handle_error(
reqwest::Client::builder() reqwest::Client::builder()
.redirect(self.redirect_policy())
.timeout(Duration::from_millis(self.timeout)) .timeout(Duration::from_millis(self.timeout))
.default_headers(self.headers.clone()) .default_headers(self.headers.clone())
.build()? .build()?
@ -167,6 +234,7 @@ impl Client {
&Client::handle_error( &Client::handle_error(
reqwest::Client::builder() reqwest::Client::builder()
.timeout(Duration::from_millis(DEFAULT_TIMEOUT_MS)) .timeout(Duration::from_millis(DEFAULT_TIMEOUT_MS))
.redirect(self.redirect_policy())
.default_headers(self.headers.clone()) .default_headers(self.headers.clone())
.build()? .build()?
.get(&self.session_url) .get(&self.session_url)

View File

@ -63,6 +63,7 @@ impl Client {
let mut stream = Client::handle_error( let mut stream = Client::handle_error(
reqwest::Client::builder() reqwest::Client::builder()
.connect_timeout(Duration::from_millis(self.timeout())) .connect_timeout(Duration::from_millis(self.timeout()))
.redirect(self.redirect_policy())
.default_headers(headers) .default_headers(headers)
.build()? .build()?
.get(event_source_url) .get(event_source_url)