diff --git a/src/marionette/error.rs b/src/marionette/error.rs index c089467..065bea0 100644 --- a/src/marionette/error.rs +++ b/src/marionette/error.rs @@ -1,5 +1,5 @@ -use std::str::Utf8Error; use std::num::ParseIntError; +use std::str::Utf8Error; #[derive(Debug)] pub enum MessageError { @@ -27,32 +27,26 @@ impl From for MessageError { } #[derive(Debug)] -pub enum HandshakeError { +pub enum ConnectionError { + Message(MessageError), Io(std::io::Error), - Parse(ParseIntError), - Utf8(Utf8Error), Json(serde_json::Error), } -impl From for HandshakeError { +impl From for ConnectionError { fn from(e: MessageError) -> Self { - match e { - MessageError::Io(e) => Self::Io(e), - MessageError::Parse(e) => Self::Parse(e), - MessageError::Utf8(e) => Self::Utf8(e), - } + Self::Message(e) } } -impl From for HandshakeError { +impl From for ConnectionError { fn from(e: std::io::Error) -> Self { Self::Io(e) } } -impl From for HandshakeError { +impl From for ConnectionError { fn from(e: serde_json::Error) -> Self { Self::Json(e) } } - diff --git a/src/marionette/mod.rs b/src/marionette/mod.rs index 1ce8e04..1fa0716 100644 --- a/src/marionette/mod.rs +++ b/src/marionette/mod.rs @@ -1,78 +1,146 @@ pub mod error; pub mod message; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; use std::time::Instant; +use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; -use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufStream}; +use tokio::io::{ + AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader, + BufWriter, +}; +use tokio::net::tcp::OwnedWriteHalf; use tokio::net::{TcpStream, ToSocketAddrs}; -use tracing::{debug, trace}; +use tokio::sync::oneshot; +use tracing::{debug, error, trace, warn}; -pub use error::{HandshakeError, MessageError}; +pub use error::{ConnectionError, MessageError}; use message::{Command, Hello, NewSessionParams, NewSessionResponse}; #[derive(Debug, Deserialize, Serialize)] -struct Message(u8, u32, Option, serde_json::Value); +struct Message(u8, u32, Option, Option); + +type SenderMap = HashMap>>; pub struct Marionette { ts: Instant, - stream: BufStream, + stream: BufWriter, + sender: Arc>, } impl Marionette { - pub async fn connect(addr: A) -> Result + pub async fn connect(addr: A) -> Result where A: ToSocketAddrs, { let conn = TcpStream::connect(addr).await?; - let stream = BufStream::new(conn); + let (read, write) = conn.into_split(); + let stream = BufWriter::new(write); + let mut rstream = BufReader::new(read); let ts = Instant::now(); - Ok(Self { ts, stream }) - } - - pub async fn handshake(&mut self) -> Result<(), HandshakeError> { - let buf = self.next_message().await?; + let sender = Arc::new(Mutex::new(HashMap::new())); + let buf = Self::next_message(&mut rstream).await?; let hello: Hello = serde_json::from_slice(&buf)?; debug!("Received hello: {:?}", hello); - self.send_message(Command::NewSession(NewSessionParams { - strict_file_interactability: true, - })) - .await?; - let buf = self.next_message().await?; - let msg: Message = serde_json::from_slice(&buf)?; - let res: NewSessionResponse = serde_json::from_value(msg.3)?; + Self::start_recv_loop(rstream, sender.clone()); + Ok(Self { ts, stream, sender }) + } + + pub async fn new_session( + &mut self, + ) -> Result { + let res = self + .send_message(Command::NewSession(NewSessionParams { + strict_file_interactability: true, + })) + .await? + .unwrap(); debug!("Received message: {:?}", res); - Ok(()) + Ok(res) } - async fn next_message(&mut self) -> Result, MessageError> { - let mut buf = vec![]; - self.stream.read_until(b':', &mut buf).await?; - let length: usize = - std::str::from_utf8(&buf[..buf.len() - 1])?.parse()?; - trace!("Message length: {:?}", length); - let mut buf = vec![0; length]; - self.stream.read_exact(&mut buf[..]).await?; - trace!("Received message: {:?}", buf); - Ok(buf) - } - - async fn send_message( + pub async fn send_message( &mut self, command: Command, - ) -> Result<(), std::io::Error> { + ) -> Result, std::io::Error> + where + T: DeserializeOwned, + { let value = serde_json::to_value(command)?; let (command, params) = ( value.get("command").unwrap().as_str().unwrap().into(), - value.get("params").unwrap().clone(), + value.get("params").cloned(), ); let msgid = (self.ts.elapsed().as_millis() % u32::MAX as u128) as u32; let message = Message(0, msgid, Some(command), params); let message = serde_json::to_string(&message)?; let message = format!("{}:{}", message.len(), message); trace!("Sending message: {}", message); + let (tx, rx) = oneshot::channel(); + { + let mut sender = self.sender.lock().unwrap(); + sender.insert(msgid, tx); + } self.stream.write_all(message.as_bytes()).await?; self.stream.flush().await?; - Ok(()) + let Some(r) = rx.await.unwrap() else { + return Ok(None) + }; + Ok(serde_json::from_value(r)?) + } + + fn start_recv_loop( + mut stream: BufReader, + sender: Arc>, + ) where + T: AsyncRead + Send + Unpin + 'static, + { + tokio::spawn(async move { + loop { + let buf = match Self::next_message(&mut stream).await { + Ok(b) => b, + Err(e) => { + error!("Error receiving message: {:?}", e); + break; + } + }; + let msg: Message = match serde_json::from_slice(&buf[..]) { + Ok(m) => m, + Err(e) => { + warn!("Error parsing message: {}", e); + continue; + } + }; + let msgid = msg.1; + let value = msg.3; + let mut sender = sender.lock().unwrap(); + if let Some(s) = sender.remove(&msgid) { + if s.send(value).is_err() { + warn!("Failed to send result to caller"); + } + } else { + warn!("Got unsolicited message {} ({:?})", msgid, value); + } + } + }); + } + + async fn next_message( + stream: &mut BufReader, + ) -> Result, MessageError> + where + T: AsyncRead + Unpin, + { + let mut buf = vec![]; + stream.read_until(b':', &mut buf).await?; + let length: usize = + std::str::from_utf8(&buf[..buf.len() - 1])?.parse()?; + trace!("Message length: {:?}", length); + let mut buf = vec![0; length]; + stream.read_exact(&mut buf[..]).await?; + trace!("Received message: {:?}", buf); + Ok(buf) } } diff --git a/src/session.rs b/src/session.rs index 1bada9f..49d9264 100644 --- a/src/session.rs +++ b/src/session.rs @@ -1,14 +1,14 @@ use tracing::{debug, info}; use crate::browser::{Browser, BrowserError}; +use crate::marionette::error::ConnectionError; use crate::marionette::Marionette; -use crate::marionette::error::HandshakeError; #[derive(Debug)] pub enum SessionError { Browser(BrowserError), Io(std::io::Error), - Handshake(HandshakeError), + Connection(ConnectionError), InvalidState(String), } @@ -24,13 +24,12 @@ impl From for SessionError { } } -impl From for SessionError { - fn from(e: HandshakeError) -> Self { - Self::Handshake(e) +impl From for SessionError { + fn from(e: ConnectionError) -> Self { + Self::Connection(e) } } - pub struct Session { browser: Browser, marionette: Marionette, @@ -48,7 +47,8 @@ impl Session { debug!("Connecting to Firefox Marionette on port {}", port); let mut marionette = Marionette::connect(("127.0.0.1", port)).await?; info!("Successfully connected to Firefox Marionette"); - marionette.handshake().await?; + let ses = marionette.new_session().await?; + debug!("Started Marionette session {}", ses.session_id); Ok(Self { browser, marionette,