diff --git a/src/marionette/error.rs b/src/marionette/error.rs index 5ef1b4d..17e019c 100644 --- a/src/marionette/error.rs +++ b/src/marionette/error.rs @@ -1,6 +1,8 @@ use std::num::ParseIntError; use std::str::Utf8Error; +use serde::Deserialize; + #[derive(Debug)] pub enum MessageError { Io(std::io::Error), @@ -90,3 +92,63 @@ impl std::error::Error for ConnectionError { } } } + +#[derive(Debug, Deserialize)] +pub struct ErrorResponse { + pub error: String, + pub message: String, + pub stacktrace: String, +} + +impl std::fmt::Display for ErrorResponse { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}: {}", self.error, self.message) + } +} + +impl std::error::Error for ErrorResponse {} + +#[derive(Debug)] +pub enum CommandError { + Io(std::io::Error), + Command(ErrorResponse), + Json(serde_json::Error), +} + +impl From for CommandError { + fn from(e: std::io::Error) -> Self { + Self::Io(e) + } +} + +impl From for CommandError { + fn from(e: ErrorResponse) -> Self { + Self::Command(e) + } +} + +impl From for CommandError { + fn from(e: serde_json::Error) -> Self { + Self::Json(e) + } +} + +impl std::fmt::Display for CommandError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Self::Io(e) => write!(f, "I/O error: {}", e), + Self::Command(e) => write!(f, "Marionette command error: {}", e), + Self::Json(e) => write!(f, "JSON deserialization error: {}", e), + } + } +} + +impl std::error::Error for CommandError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Io(e) => Some(e), + Self::Command(e) => Some(e), + Self::Json(e) => Some(e), + } + } +} diff --git a/src/marionette/mod.rs b/src/marionette/mod.rs index 041d5b4..11dacf8 100644 --- a/src/marionette/mod.rs +++ b/src/marionette/mod.rs @@ -16,16 +16,26 @@ use tokio::net::{TcpStream, ToSocketAddrs}; use tokio::sync::oneshot; use tracing::{debug, error, trace, warn}; -pub use error::{ConnectionError, MessageError}; +pub use error::{CommandError, ConnectionError, ErrorResponse, MessageError}; use message::{ Command, GetTitleResponse, Hello, NavigateParams, NewSessionParams, NewSessionResponse, }; #[derive(Debug, Deserialize, Serialize)] -struct Message(u8, u32, Option, Option); +struct Message(u8, u32, String, Option); -type SenderMap = HashMap>>; +#[derive(Debug, Deserialize, Serialize)] +struct Response( + u8, + u32, + Option, + Option, +); + +type CommandResult = Result, ErrorResponse>; + +type SenderMap = HashMap>; pub struct Marionette { ts: Instant, @@ -51,14 +61,14 @@ impl Marionette { Ok(Self { ts, stream, sender }) } - pub async fn get_title(&mut self) -> Result { + pub async fn get_title(&mut self) -> Result { let res: GetTitleResponse = self.send_message(Command::GetTitle).await?.unwrap(); debug!("Received message: {:?}", res); Ok(res.value) } - pub async fn navigate(&mut self, url: U) -> Result<(), std::io::Error> + pub async fn navigate(&mut self, url: U) -> Result<(), CommandError> where U: Into, { @@ -73,7 +83,7 @@ impl Marionette { pub async fn new_session( &mut self, - ) -> Result { + ) -> Result { let res = self .send_message(Command::NewSession(NewSessionParams { strict_file_interactability: true, @@ -87,7 +97,7 @@ impl Marionette { pub async fn send_message( &mut self, command: Command, - ) -> Result, std::io::Error> + ) -> Result, CommandError> where T: DeserializeOwned, { @@ -97,7 +107,7 @@ impl Marionette { value.get("params").cloned(), ); let msgid = (self.ts.elapsed().as_millis() % u32::MAX as u128) as u32; - let message = Message(0, msgid, Some(command), params); + let message = Message(0, msgid, command, params); let message = serde_json::to_string(&message)?; let message = format!("{}:{}", message.len(), message); trace!("Sending message: {}", message); @@ -108,10 +118,10 @@ impl Marionette { } self.stream.write_all(message.as_bytes()).await?; self.stream.flush().await?; - let Some(r) = rx.await.unwrap() else { - return Ok(None) - }; - Ok(serde_json::from_value(r)?) + match rx.await.unwrap()? { + Some(r) => Ok(serde_json::from_value(r)?), + None => Ok(None), + } } fn start_recv_loop( @@ -129,7 +139,7 @@ impl Marionette { break; } }; - let msg: Message = match serde_json::from_slice(&buf[..]) { + let msg: Response = match serde_json::from_slice(&buf[..]) { Ok(m) => m, Err(e) => { warn!("Error parsing message: {}", e); @@ -137,10 +147,23 @@ impl Marionette { } }; let msgid = msg.1; + let error = msg.2; let value = msg.3; let mut sender = sender.lock().unwrap(); if let Some(s) = sender.remove(&msgid) { - if s.send(value).is_err() { + let r; + if let Some(e) = error { + match serde_json::from_value(e) { + Ok(e) => r = Err(e), + Err(e) => { + warn!("Error parsing error response: {}", e); + continue; + } + } + } else { + r = Ok(value); + } + if s.send(r).is_err() { warn!("Failed to send result to caller"); } } else { diff --git a/src/session.rs b/src/session.rs index a3f8367..faaf1a9 100644 --- a/src/session.rs +++ b/src/session.rs @@ -4,7 +4,7 @@ use tracing::{debug, error, info, warn}; use crate::browser::{Browser, BrowserError}; use crate::config::Configuration; -use crate::marionette::error::ConnectionError; +use crate::marionette::error::{CommandError, ConnectionError}; use crate::marionette::Marionette; use crate::mqtt::{Message, MqttClient, MqttPublisher}; @@ -14,6 +14,7 @@ pub enum SessionError { Io(std::io::Error), Connection(ConnectionError), InvalidState(String), + Command(CommandError), } impl From for SessionError { @@ -34,6 +35,12 @@ impl From for SessionError { } } +impl From for SessionError { + fn from(e: CommandError) -> Self { + Self::Command(e) + } +} + impl std::fmt::Display for SessionError { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { @@ -41,6 +48,7 @@ impl std::fmt::Display for SessionError { Self::Io(e) => write!(f, "I/O error: {}", e), Self::Connection(e) => write!(f, "Connection error: {}", e), Self::InvalidState(e) => write!(f, "Invalid state: {}", e), + Self::Command(e) => write!(f, "Marionette command failed: {}", e), } } } @@ -52,6 +60,7 @@ impl std::error::Error for SessionError { Self::Io(e) => Some(e), Self::Connection(e) => Some(e), Self::InvalidState(_) => None, + Self::Command(e) => Some(e), } } }