Add support for trusted redirects.
parent
45f0aa3d81
commit
4674fb3a39
|
@ -15,7 +15,8 @@ readme = "README.md"
|
||||||
serde = { version = "1.0", features = ["derive"]}
|
serde = { version = "1.0", features = ["derive"]}
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
chrono = { version = "0.4", features = ["serde"]}
|
chrono = { version = "0.4", features = ["serde"]}
|
||||||
reqwest = { version = "0.11", default-features = false, features = ["stream", "rustls-tls"]}
|
#reqwest = { version = "0.11", default-features = false, features = ["stream", "rustls-tls"]}
|
||||||
|
reqwest = { git = "https://github.com/stalwartlabs/reqwest.git", default-features = false, features = ["stream", "rustls-tls"]}
|
||||||
futures-util = "0.3"
|
futures-util = "0.3"
|
||||||
async-stream = "0.3"
|
async-stream = "0.3"
|
||||||
base64 = "0.13"
|
base64 = "0.13"
|
||||||
|
|
|
@ -39,6 +39,7 @@ impl Client {
|
||||||
Client::handle_error(
|
Client::handle_error(
|
||||||
reqwest::Client::builder()
|
reqwest::Client::builder()
|
||||||
.timeout(Duration::from_millis(self.timeout()))
|
.timeout(Duration::from_millis(self.timeout()))
|
||||||
|
.redirect(self.redirect_policy())
|
||||||
.default_headers(headers)
|
.default_headers(headers)
|
||||||
.build()?
|
.build()?
|
||||||
.get(download_url)
|
.get(download_url)
|
||||||
|
|
|
@ -48,6 +48,7 @@ impl Client {
|
||||||
&Client::handle_error(
|
&Client::handle_error(
|
||||||
reqwest::Client::builder()
|
reqwest::Client::builder()
|
||||||
.timeout(Duration::from_millis(self.timeout()))
|
.timeout(Duration::from_millis(self.timeout()))
|
||||||
|
.redirect(self.redirect_policy())
|
||||||
.default_headers(self.headers().clone())
|
.default_headers(self.headers().clone())
|
||||||
.build()?
|
.build()?
|
||||||
.post(upload_url)
|
.post(upload_url)
|
||||||
|
|
|
@ -6,9 +6,10 @@ use std::{
|
||||||
time::Duration,
|
time::Duration,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use ahash::AHashSet;
|
||||||
use reqwest::{
|
use reqwest::{
|
||||||
header::{self},
|
header::{self},
|
||||||
Response,
|
redirect, Response,
|
||||||
};
|
};
|
||||||
use serde::de::DeserializeOwned;
|
use serde::de::DeserializeOwned;
|
||||||
|
|
||||||
|
@ -36,6 +37,7 @@ pub struct Client {
|
||||||
session_url: String,
|
session_url: String,
|
||||||
api_url: String,
|
api_url: String,
|
||||||
session_updated: AtomicBool,
|
session_updated: AtomicBool,
|
||||||
|
trusted_hosts: Arc<AHashSet<String>>,
|
||||||
|
|
||||||
upload_url: Vec<URLPart<blob::URLParameter>>,
|
upload_url: Vec<URLPart<blob::URLParameter>>,
|
||||||
download_url: Vec<URLPart<blob::URLParameter>>,
|
download_url: Vec<URLPart<blob::URLParameter>>,
|
||||||
|
@ -53,6 +55,22 @@ pub struct Client {
|
||||||
|
|
||||||
impl Client {
|
impl Client {
|
||||||
pub async fn connect(url: &str, credentials: impl Into<Credentials>) -> crate::Result<Self> {
|
pub async fn connect(url: &str, credentials: impl Into<Credentials>) -> crate::Result<Self> {
|
||||||
|
Self::connect_(url, credentials, None::<Vec<String>>).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn connect_with_trusted(
|
||||||
|
url: &str,
|
||||||
|
credentials: impl Into<Credentials>,
|
||||||
|
trusted_hosts: impl IntoIterator<Item = impl Into<String>>,
|
||||||
|
) -> crate::Result<Self> {
|
||||||
|
Self::connect_(url, credentials, trusted_hosts.into()).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect_(
|
||||||
|
url: &str,
|
||||||
|
credentials: impl Into<Credentials>,
|
||||||
|
trusted_hosts: Option<impl IntoIterator<Item = impl Into<String>>>,
|
||||||
|
) -> crate::Result<Self> {
|
||||||
let authorization = match credentials.into() {
|
let authorization = match credentials.into() {
|
||||||
Credentials::Basic(s) => format!("Basic {}", s),
|
Credentials::Basic(s) => format!("Basic {}", s),
|
||||||
Credentials::Bearer(s) => format!("Bearer {}", s),
|
Credentials::Bearer(s) => format!("Bearer {}", s),
|
||||||
|
@ -67,10 +85,31 @@ impl Client {
|
||||||
header::HeaderValue::from_str(&authorization).unwrap(),
|
header::HeaderValue::from_str(&authorization).unwrap(),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let trusted_hosts = Arc::new(
|
||||||
|
trusted_hosts
|
||||||
|
.map(|hosts| hosts.into_iter().map(|h| h.into()).collect::<AHashSet<_>>())
|
||||||
|
.unwrap_or_default(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let trusted_hosts_ = trusted_hosts.clone();
|
||||||
let session: Session = serde_json::from_slice(
|
let session: Session = serde_json::from_slice(
|
||||||
&Client::handle_error(
|
&Client::handle_error(
|
||||||
reqwest::Client::builder()
|
reqwest::Client::builder()
|
||||||
.timeout(Duration::from_millis(DEFAULT_TIMEOUT_MS))
|
.timeout(Duration::from_millis(DEFAULT_TIMEOUT_MS))
|
||||||
|
.redirect(redirect::Policy::custom(move |attempt| {
|
||||||
|
if attempt.previous().len() > 5 {
|
||||||
|
attempt.error("Too many redirects.")
|
||||||
|
} else if matches!( attempt.url().host_str(), Some(host) if trusted_hosts_.contains(host) )
|
||||||
|
{
|
||||||
|
attempt.follow_trusted()
|
||||||
|
} else {
|
||||||
|
let message = format!(
|
||||||
|
"Aborting redirect request to unknown host '{}'.",
|
||||||
|
attempt.url().host_str().unwrap_or("")
|
||||||
|
);
|
||||||
|
attempt.error(message)
|
||||||
|
}
|
||||||
|
}))
|
||||||
.default_headers(headers.clone())
|
.default_headers(headers.clone())
|
||||||
.build()?
|
.build()?
|
||||||
.get(url)
|
.get(url)
|
||||||
|
@ -101,6 +140,7 @@ impl Client {
|
||||||
session: parking_lot::Mutex::new(Arc::new(session)),
|
session: parking_lot::Mutex::new(Arc::new(session)),
|
||||||
session_url: url.to_string(),
|
session_url: url.to_string(),
|
||||||
session_updated: true.into(),
|
session_updated: true.into(),
|
||||||
|
trusted_hosts,
|
||||||
#[cfg(feature = "websockets")]
|
#[cfg(feature = "websockets")]
|
||||||
authorization,
|
authorization,
|
||||||
timeout: DEFAULT_TIMEOUT_MS,
|
timeout: DEFAULT_TIMEOUT_MS,
|
||||||
|
@ -116,6 +156,14 @@ impl Client {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_trusted_hosts(
|
||||||
|
&mut self,
|
||||||
|
trusted_hosts: impl IntoIterator<Item = impl Into<String>>,
|
||||||
|
) -> &mut Self {
|
||||||
|
self.trusted_hosts = Arc::new(trusted_hosts.into_iter().map(|h| h.into()).collect());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
pub fn timeout(&self) -> u64 {
|
pub fn timeout(&self) -> u64 {
|
||||||
self.timeout
|
self.timeout
|
||||||
}
|
}
|
||||||
|
@ -132,6 +180,24 @@ impl Client {
|
||||||
&self.headers
|
&self.headers
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn redirect_policy(&self) -> redirect::Policy {
|
||||||
|
let trusted_hosts = self.trusted_hosts.clone();
|
||||||
|
redirect::Policy::custom(move |attempt| {
|
||||||
|
if attempt.previous().len() > 5 {
|
||||||
|
attempt.error("Too many redirects.")
|
||||||
|
} else if matches!( attempt.url().host_str(), Some(host) if trusted_hosts.contains(host) )
|
||||||
|
{
|
||||||
|
attempt.follow_trusted()
|
||||||
|
} else {
|
||||||
|
let message = format!(
|
||||||
|
"Aborting redirect request to unknown host '{}'.",
|
||||||
|
attempt.url().host_str().unwrap_or("")
|
||||||
|
);
|
||||||
|
attempt.error(message)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn send<R>(
|
pub async fn send<R>(
|
||||||
&self,
|
&self,
|
||||||
request: &request::Request<'_>,
|
request: &request::Request<'_>,
|
||||||
|
@ -142,6 +208,7 @@ impl Client {
|
||||||
let response: response::Response<R> = serde_json::from_slice(
|
let response: response::Response<R> = serde_json::from_slice(
|
||||||
&Client::handle_error(
|
&Client::handle_error(
|
||||||
reqwest::Client::builder()
|
reqwest::Client::builder()
|
||||||
|
.redirect(self.redirect_policy())
|
||||||
.timeout(Duration::from_millis(self.timeout))
|
.timeout(Duration::from_millis(self.timeout))
|
||||||
.default_headers(self.headers.clone())
|
.default_headers(self.headers.clone())
|
||||||
.build()?
|
.build()?
|
||||||
|
@ -167,6 +234,7 @@ impl Client {
|
||||||
&Client::handle_error(
|
&Client::handle_error(
|
||||||
reqwest::Client::builder()
|
reqwest::Client::builder()
|
||||||
.timeout(Duration::from_millis(DEFAULT_TIMEOUT_MS))
|
.timeout(Duration::from_millis(DEFAULT_TIMEOUT_MS))
|
||||||
|
.redirect(self.redirect_policy())
|
||||||
.default_headers(self.headers.clone())
|
.default_headers(self.headers.clone())
|
||||||
.build()?
|
.build()?
|
||||||
.get(&self.session_url)
|
.get(&self.session_url)
|
||||||
|
|
|
@ -63,6 +63,7 @@ impl Client {
|
||||||
let mut stream = Client::handle_error(
|
let mut stream = Client::handle_error(
|
||||||
reqwest::Client::builder()
|
reqwest::Client::builder()
|
||||||
.connect_timeout(Duration::from_millis(self.timeout()))
|
.connect_timeout(Duration::from_millis(self.timeout()))
|
||||||
|
.redirect(self.redirect_policy())
|
||||||
.default_headers(headers)
|
.default_headers(headers)
|
||||||
.build()?
|
.build()?
|
||||||
.get(event_source_url)
|
.get(event_source_url)
|
||||||
|
|
Loading…
Reference in New Issue