// NOTE(review): in this copy of the file the generic argument lists appear to
// have been stripped (e.g. `parking_lot::Mutex>`, `Vec>`, `impl Into)`,
// `crate::Result {`, `send(` without its type parameter, `impl From for`).
// The code will not compile as shown; restore the `<...>` arguments from
// version control before building.
use std::{
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
    time::Duration,
};

use ahash::AHashSet;
use reqwest::{
    header::{self},
    redirect, Response,
};
use serde::de::DeserializeOwned;

use crate::{
    blob,
    core::{
        request::{self, Request},
        response,
        session::{Session, URLPart},
    },
    event_source, Error,
};

/// Timeout in milliseconds (10 s) used when fetching the session resource,
/// and the initial value of the per-client request timeout.
const DEFAULT_TIMEOUT_MS: u64 = 10 * 1000;
/// `User-Agent` header value: `jmap-client/<crate version>`.
static USER_AGENT: &str = concat!("jmap-client/", env!("CARGO_PKG_VERSION"));

/// Credentials used to build the `Authorization` header.
///
/// The inner string is placed verbatim after the scheme name (`Basic ` or
/// `Bearer `), so a `Basic` value must already be encoded; use
/// [`Credentials::basic`] to build it from a username and password.
#[derive(Debug, PartialEq, Eq)]
pub enum Credentials {
    Basic(String),
    Bearer(String),
}

/// A JMAP client bound to a single session resource.
pub struct Client {
    /// Most recently fetched session, handed out as cheap `Arc` snapshots.
    session: parking_lot::Mutex>,
    /// URL the session resource was fetched from (reused by `refresh_session`).
    session_url: String,
    /// API endpoint taken from the session at connect time; `send` POSTs here.
    api_url: String,
    /// `true` after connect/refresh; cleared by `send` when a response reports
    /// a session state different from the cached session's.
    session_updated: AtomicBool,
    /// Host names that HTTP redirects are allowed to target.
    trusted_hosts: Arc>,
    /// URL templates parsed from the session at connect time.
    upload_url: Vec>,
    download_url: Vec>,
    event_source_url: Vec>,
    /// Default headers for requests built by this client
    /// (User-Agent, Authorization, Content-Type: application/json).
    headers: header::HeaderMap,
    default_account_id: String,
    /// Request timeout in milliseconds, applied by `send`.
    timeout: u64,
    #[cfg(feature = "websockets")]
    pub(crate) authorization: String,
    #[cfg(feature = "websockets")]
    pub(crate) ws: tokio::sync::Mutex>,
}

impl Client {
    /// Connects to the JMAP session resource at `url` with an empty
    /// trusted-host set.
    ///
    /// Because no host is trusted, every HTTP redirect returned by the server
    /// is rejected; use [`Client::connect_with_trusted`] to allow some.
    pub async fn connect(url: &str, credentials: impl Into) -> crate::Result {
        Self::connect_(url, credentials, None::>).await
    }

    /// Like [`Client::connect`], but HTTP redirects whose target host name is
    /// in `trusted_hosts` are allowed.
    pub async fn connect_with_trusted(
        url: &str,
        credentials: impl Into,
        trusted_hosts: impl IntoIterator>,
    ) -> crate::Result {
        // `.into()` turns the iterator into the `Option` that `connect_`
        // expects (i.e. `Some(..)`).
        Self::connect_(url, credentials, trusted_hosts.into()).await
    }

    /// Shared implementation of [`Client::connect`] and
    /// [`Client::connect_with_trusted`].
    ///
    /// Fetches the session resource at `url` (with `DEFAULT_TIMEOUT_MS` and a
    /// trusted-host redirect policy), parses it, and builds a `Client` from
    /// its API URL, URL templates and primary accounts.
    ///
    /// # Errors
    /// Returns an error if building the HTTP client or sending the request
    /// fails, if the server answers with a non-success status (see
    /// [`Client::handle_error`]), if the body is not a valid session object,
    /// or if one of the session's URL templates fails to parse.
    ///
    /// # Panics
    /// Panics if the formatted `Authorization` value is not a valid header
    /// value (`HeaderValue::from_str(..).unwrap()`).
    async fn connect_(
        url: &str,
        credentials: impl Into,
        trusted_hosts: Option>>,
    ) -> crate::Result {
        let authorization = match credentials.into() {
            Credentials::Basic(s) => format!("Basic {}", s),
            Credentials::Bearer(s) => format!("Bearer {}", s),
        };
        let mut headers = header::HeaderMap::new();
        headers.insert(
            header::USER_AGENT,
            header::HeaderValue::from_static(USER_AGENT),
        );
        // NOTE(review): this unwrap is reachable from caller-supplied
        // credentials (e.g. a token containing a control character); consider
        // returning an error instead of panicking.
        headers.insert(
            header::AUTHORIZATION,
            header::HeaderValue::from_str(&authorization).unwrap(),
        );

        // No trusted hosts given -> empty set.
        let trusted_hosts = Arc::new(
            trusted_hosts
                .map(|hosts| hosts.into_iter().map(|h| h.into()).collect::>())
                .unwrap_or_default(),
        );

        // The `move` redirect closure below takes this clone; the original
        // `Arc` is stored in the returned `Client`.
        let trusted_hosts_ = trusted_hosts.clone();
        let session: Session = serde_json::from_slice(
            &Client::handle_error(
                reqwest::Client::builder()
                    .timeout(Duration::from_millis(DEFAULT_TIMEOUT_MS))
                    // Same rules as `redirect_policy`: error after more than 5
                    // previous URLs, and only hosts in the trusted set are
                    // allowed; anything else errors.
                    // NOTE(review): duplicated logic; keep in sync with
                    // `redirect_policy` or share a helper.
                    .redirect(redirect::Policy::custom(move |attempt| {
                        if attempt.previous().len() > 5 {
                            attempt.error("Too many redirects.")
                        } else if matches!(
                            attempt.url().host_str(),
                            Some(host) if trusted_hosts_.contains(host)
                        ) {
                            attempt.follow_trusted()
                        } else {
                            let message = format!(
                                "Aborting redirect request to unknown host '{}'.",
                                attempt.url().host_str().unwrap_or("")
                            );
                            attempt.error(message)
                        }
                    }))
                    .default_headers(headers.clone())
                    .build()?
                    .get(url)
                    .send()
                    .await?,
            )
            .await?
            .bytes()
            .await?,
        )?;

        // Second element of the first primary-account entry (presumably the
        // account id; confirm against `Session::primary_accounts`), or an
        // empty string if there is none.
        let default_account_id = session
            .primary_accounts()
            .next()
            .map(|a| a.1.to_string())
            .unwrap_or_default();

        // Inserted only after the session GET above, so the JSON content
        // type applies to requests made later with `headers`.
        headers.insert(
            header::CONTENT_TYPE,
            header::HeaderValue::from_static("application/json"),
        );

        Ok(Client {
            download_url: URLPart::parse(session.download_url())?,
            upload_url: URLPart::parse(session.upload_url())?,
            event_source_url: URLPart::parse(session.event_source_url())?,
            api_url: session.api_url().to_string(),
            session: parking_lot::Mutex::new(Arc::new(session)),
            session_url: url.to_string(),
            session_updated: true.into(),
            trusted_hosts,
            #[cfg(feature = "websockets")]
            authorization,
            timeout: DEFAULT_TIMEOUT_MS,
            headers,
            default_account_id,
            #[cfg(feature = "websockets")]
            ws: None.into(),
        })
    }

    /// Sets the request timeout in milliseconds used by [`Client::send`].
    ///
    /// Note that [`Client::refresh_session`] always uses the 10 s default and
    /// does not read this value.
    pub fn set_timeout(&mut self, timeout: u64) -> &mut Self {
        self.timeout = timeout;
        self
    }

    /// Replaces the set of host names that redirects may target.
    ///
    /// Only affects policies built by later `redirect_policy` calls.
    pub fn set_trusted_hosts(
        &mut self,
        trusted_hosts: impl IntoIterator>,
    ) -> &mut Self {
        self.trusted_hosts = Arc::new(trusted_hosts.into_iter().map(|h| h.into()).collect());
        self
    }

    /// Request timeout in milliseconds.
    pub fn timeout(&self) -> u64 {
        self.timeout
    }

    /// Returns a snapshot of the cached session.
    ///
    /// This is an `Arc` clone; a later `refresh_session` swaps in a new `Arc`
    /// and does not change the value returned here.
    pub fn session(&self) -> Arc {
        self.session.lock().clone()
    }

    /// URL of the session resource this client was created from.
    pub fn session_url(&self) -> &str {
        &self.session_url
    }

    /// Default headers used by this client's requests.
    pub fn headers(&self) -> &header::HeaderMap {
        &self.headers
    }

    /// Builds a redirect policy from the current trusted-host set.
    ///
    /// A redirect errors once more than 5 previous URLs are recorded, or when
    /// its target host is not in the trusted set.
    pub(crate) fn redirect_policy(&self) -> redirect::Policy {
        // NOTE(review): same closure as the one inlined in `connect_`.
        let trusted_hosts = self.trusted_hosts.clone();
        redirect::Policy::custom(move |attempt| {
            if attempt.previous().len() > 5 {
                attempt.error("Too many redirects.")
            } else if matches!(
                attempt.url().host_str(),
                Some(host) if trusted_hosts.contains(host)
            ) {
                attempt.follow_trusted()
            } else {
                let message = format!(
                    "Aborting redirect request to unknown host '{}'.",
                    attempt.url().host_str().unwrap_or("")
                );
                attempt.error(message)
            }
        })
    }

    /// Serializes `request` as JSON, POSTs it to the API URL, and parses the
    /// JMAP response.
    ///
    /// If the session state reported by the response differs from the cached
    /// session's state, the session is marked stale (see
    /// [`Client::is_session_updated`]). The cached session itself is not
    /// refreshed.
    ///
    /// # Errors
    /// Fails on serialization, transport, non-success HTTP status (see
    /// [`Client::handle_error`]), or response-parsing errors.
    pub async fn send(
        &self,
        request: &request::Request<'_>,
    ) -> crate::Result>
    where
        R: DeserializeOwned,
    {
        // NOTE(review): a new `reqwest::Client` is built for every call, so
        // its connection pool is never reused across requests.
        let response: response::Response = serde_json::from_slice(
            &Client::handle_error(
                reqwest::Client::builder()
                    .redirect(self.redirect_policy())
                    .timeout(Duration::from_millis(self.timeout))
                    .default_headers(self.headers.clone())
                    .build()?
                    .post(&self.api_url)
                    .body(serde_json::to_string(&request)?)
                    .send()
                    .await?,
            )
            .await?
            .bytes()
            .await?,
        )?;

        // The server reports a different session state: flag the cache as stale.
        if response.session_state() != self.session.lock().state() {
            self.session_updated.store(false, Ordering::Relaxed);
        }

        Ok(response)
    }

    /// Re-fetches the session resource from `session_url`, replaces the
    /// cached session, and marks it as up to date.
    ///
    /// Always uses `DEFAULT_TIMEOUT_MS`, not the configured timeout.
    pub async fn refresh_session(&self) -> crate::Result<()> {
        // NOTE(review): only `session` is replaced; `api_url`, `upload_url`,
        // `download_url` and `event_source_url` keep the values parsed at
        // connect time.
        let session: Session = serde_json::from_slice(
            &Client::handle_error(
                reqwest::Client::builder()
                    .timeout(Duration::from_millis(DEFAULT_TIMEOUT_MS))
                    .redirect(self.redirect_policy())
                    .default_headers(self.headers.clone())
                    .build()?
                    .get(&self.session_url)
                    .send()
                    .await?,
            )
            .await?
            .bytes()
            .await?,
        )?;
        *self.session.lock() = Arc::new(session);
        self.session_updated.store(true, Ordering::Relaxed);
        Ok(())
    }

    /// `true` after connecting or a successful [`Client::refresh_session`].
    ///
    /// Becomes `false` once [`Client::send`] sees a response whose session
    /// state differs from the cached session.
    pub fn is_session_updated(&self) -> bool {
        self.session_updated.load(Ordering::Relaxed)
    }

    /// Overrides the account id returned by [`Client::default_account_id`].
    pub fn set_default_account_id(&mut self, defaul_account_id: impl Into) -> &mut Self {
        self.default_account_id = defaul_account_id.into();
        self
    }

    /// Default account id.
    ///
    /// Initialised from the session's primary accounts at connect time, and
    /// empty if there were none.
    pub fn default_account_id(&self) -> &str {
        &self.default_account_id
    }

    /// Starts a new request bound to this client.
    pub fn build(&self) -> Request<'_> {
        Request::new(self)
    }

    /// Download URL template parsed from the session at connect time.
    pub fn download_url(&self) -> &[URLPart] {
        &self.download_url
    }

    /// Upload URL template parsed from the session at connect time.
    pub fn upload_url(&self) -> &[URLPart] {
        &self.upload_url
    }

    /// Event-source URL template parsed from the session at connect time.
    pub fn event_source_url(&self) -> &[URLPart] {
        &self.event_source_url
    }

    /// Passes successful responses through and converts failures into errors.
    ///
    /// For a non-success status:
    /// - If `Content-Type` is exactly `application/problem+json`, the body is
    ///   parsed into [`Error::Problem`].
    /// - Otherwise the result is [`Error::Server`] holding the status text
    ///   (e.g. `"404 Not Found"`).
    pub async fn handle_error(response: Response) -> crate::Result {
        // NOTE(review): the byte-exact comparison means a parameterised value
        // such as `application/problem+json; charset=utf-8` is reported as
        // `Error::Server` rather than parsed as a problem document.
        if response.status().is_success() {
            Ok(response)
        } else if let Some(b"application/problem+json") = response
            .headers()
            .get(header::CONTENT_TYPE)
            .map(|h| h.as_bytes())
        {
            Err(Error::Problem(serde_json::from_slice(
                &response.bytes().await?,
            )?))
        } else {
            Err(Error::Server(format!("{}", response.status())))
        }
    }
}

impl Credentials {
    /// HTTP Basic credentials: base64 encoding of `username:password`.
    pub fn basic(username: &str, password: &str) -> Self {
        Credentials::Basic(base64::encode(format!("{}:{}", username, password)))
    }

    /// Bearer-token credentials.
    pub fn bearer(token: impl Into) -> Self {
        Credentials::Bearer(token.into())
    }
}

/// A bare string slice is treated as a bearer token.
impl From<&str> for Credentials {
    fn from(s: &str) -> Self {
        Credentials::bearer(s.to_string())
    }
}

/// A bare owned string is treated as a bearer token.
impl From for Credentials {
    fn from(s: String) -> Self {
        Credentials::bearer(s)
    }
}

/// A `(username, password)` pair becomes Basic credentials.
impl From<(&str, &str)> for Credentials {
    fn from((username, password): (&str, &str)) -> Self {
        Credentials::basic(username, password)
    }
}

/// An owned `(username, password)` pair becomes Basic credentials.
impl From<(String, String)> for Credentials {
    fn from((username, password): (String, String)) -> Self {
        Credentials::basic(&username, &password)
    }
}

#[cfg(test)]
mod tests {
    use crate::core::response::{Response, TaggedMethodResponse};

    /// Deserializes a sample response containing `Email/query`, `Email/get`
    /// and `Thread/get` method responses.
    ///
    /// Passes if parsing succeeds; the contents are not inspected.
    #[test]
    fn test_deserialize() {
        let _r: Response = serde_json::from_slice(
            br#"{"sessionState": "123", "methodResponses": [[ "Email/query", {
"accountId": "A1", "queryState": "abcdefg", "canCalculateChanges": true, "position": 0, "total": 101, "ids": [ "msg1023", "msg223", "msg110", "msg93", "msg91", "msg38", "msg36", "msg33", "msg11", "msg1" ] }, "t0" ], [ "Email/get", { "accountId": "A1", "state": "123456", "list": [{ "id": "msg1023", "threadId": "trd194" }, { "id": "msg223", "threadId": "trd114" } ], "notFound": [] }, "t1" ], [ "Thread/get", { "accountId": "A1", "state": "123456", "list": [{ "id": "trd194", "emailIds": [ "msg1020", "msg1021", "msg1023" ] }, { "id": "trd114", "emailIds": [ "msg201", "msg223" ] } ], "notFound": [] }, "t2" ]]}"#,
        )
        .unwrap();
    }
}