pub mod error; pub mod message; use std::collections::HashMap; use std::sync::{Arc, Mutex}; use std::time::Instant; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use tokio::io::{ AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader, BufWriter, }; use tokio::net::tcp::OwnedWriteHalf; use tokio::net::{TcpStream, ToSocketAddrs}; use tokio::sync::oneshot; use tracing::{debug, error, trace, warn}; pub use error::{ConnectionError, MessageError}; use message::{ Command, GetTitleResponse, Hello, NavigateParams, NewSessionParams, NewSessionResponse, }; #[derive(Debug, Deserialize, Serialize)] struct Message(u8, u32, Option, Option); type SenderMap = HashMap>>; pub struct Marionette { ts: Instant, stream: BufWriter, sender: Arc>, } impl Marionette { pub async fn connect(addr: A) -> Result where A: ToSocketAddrs, { let conn = TcpStream::connect(addr).await?; let (read, write) = conn.into_split(); let stream = BufWriter::new(write); let mut rstream = BufReader::new(read); let ts = Instant::now(); let sender = Arc::new(Mutex::new(HashMap::new())); let buf = Self::next_message(&mut rstream).await?; let hello: Hello = serde_json::from_slice(&buf)?; debug!("Received hello: {:?}", hello); Self::start_recv_loop(rstream, sender.clone()); Ok(Self { ts, stream, sender }) } pub async fn get_title(&mut self) -> Result { let res: GetTitleResponse = self.send_message(Command::GetTitle).await?.unwrap(); debug!("Received message: {:?}", res); Ok(res.value) } pub async fn navigate(&mut self, url: U) -> Result<(), std::io::Error> where U: Into, { let res: Option = self .send_message(Command::Navigate(NavigateParams { url: url.into(), })) .await?; debug!("Received message: {:?}", res); Ok(()) } pub async fn new_session( &mut self, ) -> Result { let res = self .send_message(Command::NewSession(NewSessionParams { strict_file_interactability: true, })) .await? .unwrap(); debug!("Received message: {:?}", res); Ok(res) } pub async fn send_message( &mut self, command: Command, ) -> Result, std::io::Error> where T: DeserializeOwned, { let value = serde_json::to_value(command)?; let (command, params) = ( value.get("command").unwrap().as_str().unwrap().into(), value.get("params").cloned(), ); let msgid = (self.ts.elapsed().as_millis() % u32::MAX as u128) as u32; let message = Message(0, msgid, Some(command), params); let message = serde_json::to_string(&message)?; let message = format!("{}:{}", message.len(), message); trace!("Sending message: {}", message); let (tx, rx) = oneshot::channel(); { let mut sender = self.sender.lock().unwrap(); sender.insert(msgid, tx); } self.stream.write_all(message.as_bytes()).await?; self.stream.flush().await?; let Some(r) = rx.await.unwrap() else { return Ok(None) }; Ok(serde_json::from_value(r)?) } fn start_recv_loop( mut stream: BufReader, sender: Arc>, ) where T: AsyncRead + Send + Unpin + 'static, { tokio::spawn(async move { loop { let buf = match Self::next_message(&mut stream).await { Ok(b) => b, Err(e) => { error!("Error receiving message: {:?}", e); break; } }; let msg: Message = match serde_json::from_slice(&buf[..]) { Ok(m) => m, Err(e) => { warn!("Error parsing message: {}", e); continue; } }; let msgid = msg.1; let value = msg.3; let mut sender = sender.lock().unwrap(); if let Some(s) = sender.remove(&msgid) { if s.send(value).is_err() { warn!("Failed to send result to caller"); } } else { warn!("Got unsolicited message {} ({:?})", msgid, value); } } }); } async fn next_message( stream: &mut BufReader, ) -> Result, MessageError> where T: AsyncRead + Unpin, { let mut buf = vec![]; stream.read_until(b':', &mut buf).await?; let length: usize = std::str::from_utf8(&buf[..buf.len() - 1])?.parse()?; trace!("Message length: {:?}", length); let mut buf = vec![0; length]; stream.read_exact(&mut buf[..]).await?; trace!("Received message: {:?}", buf); Ok(buf) } }