mqttmarionette/src/marionette/mod.rs

170 lines
5.3 KiB
Rust

pub mod error;
pub mod message;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Instant;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use tokio::io::{
AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader,
BufWriter,
};
use tokio::net::tcp::OwnedWriteHalf;
use tokio::net::{TcpStream, ToSocketAddrs};
use tokio::sync::oneshot;
use tracing::{debug, error, trace, warn};
pub use error::{ConnectionError, MessageError};
use message::{
Command, GetTitleResponse, Hello, NavigateParams, NewSessionParams,
NewSessionResponse,
};
/// A raw Marionette packet laid out as a four-element JSON array:
/// `[type, message id, command name, parameters]`.
///
/// Outgoing commands are built with packet type `0`.
#[derive(Serialize, Deserialize, Debug)]
struct Message(u8, u32, Option<String>, Option<serde_json::Value>);

/// Callers waiting for a reply, keyed by the message id they sent; each entry
/// receives the (possibly null) result of its command.
type SenderMap = HashMap<u32, oneshot::Sender<Option<serde_json::Value>>>;
/// A client connection to a Marionette server.
///
/// Commands are written through a buffered socket half; replies are routed
/// back to the waiting caller through the shared map of pending senders.
pub struct Marionette {
    /// Moment the connection was opened; message ids are derived from the
    /// time elapsed since then.
    ts: Instant,
    /// Buffered write half of the TCP connection.
    stream: BufWriter<OwnedWriteHalf>,
    /// Senders for in-flight commands, shared with the background receive task.
    sender: Arc<Mutex<SenderMap>>,
}
impl Marionette {
pub async fn connect<A>(addr: A) -> Result<Self, ConnectionError>
where
A: ToSocketAddrs,
{
let conn = TcpStream::connect(addr).await?;
let (read, write) = conn.into_split();
let stream = BufWriter::new(write);
let mut rstream = BufReader::new(read);
let ts = Instant::now();
let sender = Arc::new(Mutex::new(HashMap::new()));
let buf = Self::next_message(&mut rstream).await?;
let hello: Hello = serde_json::from_slice(&buf)?;
debug!("Received hello: {:?}", hello);
Self::start_recv_loop(rstream, sender.clone());
Ok(Self { ts, stream, sender })
}
pub async fn get_title(&mut self) -> Result<String, std::io::Error> {
let res: GetTitleResponse =
self.send_message(Command::GetTitle).await?.unwrap();
debug!("Received message: {:?}", res);
Ok(res.value)
}
pub async fn navigate<U>(&mut self, url: U) -> Result<(), std::io::Error>
where
U: Into<String>,
{
let res: Option<serde_json::Value> = self
.send_message(Command::Navigate(NavigateParams {
url: url.into(),
}))
.await?;
debug!("Received message: {:?}", res);
Ok(())
}
pub async fn new_session(
&mut self,
) -> Result<NewSessionResponse, std::io::Error> {
let res = self
.send_message(Command::NewSession(NewSessionParams {
strict_file_interactability: true,
}))
.await?
.unwrap();
debug!("Received message: {:?}", res);
Ok(res)
}
pub async fn send_message<T>(
&mut self,
command: Command,
) -> Result<Option<T>, std::io::Error>
where
T: DeserializeOwned,
{
let value = serde_json::to_value(command)?;
let (command, params) = (
value.get("command").unwrap().as_str().unwrap().into(),
value.get("params").cloned(),
);
let msgid = (self.ts.elapsed().as_millis() % u32::MAX as u128) as u32;
let message = Message(0, msgid, Some(command), params);
let message = serde_json::to_string(&message)?;
let message = format!("{}:{}", message.len(), message);
trace!("Sending message: {}", message);
let (tx, rx) = oneshot::channel();
{
let mut sender = self.sender.lock().unwrap();
sender.insert(msgid, tx);
}
self.stream.write_all(message.as_bytes()).await?;
self.stream.flush().await?;
let Some(r) = rx.await.unwrap() else {
return Ok(None)
};
Ok(serde_json::from_value(r)?)
}
fn start_recv_loop<T>(
mut stream: BufReader<T>,
sender: Arc<Mutex<SenderMap>>,
) where
T: AsyncRead + Send + Unpin + 'static,
{
tokio::spawn(async move {
loop {
let buf = match Self::next_message(&mut stream).await {
Ok(b) => b,
Err(e) => {
error!("Error receiving message: {:?}", e);
break;
}
};
let msg: Message = match serde_json::from_slice(&buf[..]) {
Ok(m) => m,
Err(e) => {
warn!("Error parsing message: {}", e);
continue;
}
};
let msgid = msg.1;
let value = msg.3;
let mut sender = sender.lock().unwrap();
if let Some(s) = sender.remove(&msgid) {
if s.send(value).is_err() {
warn!("Failed to send result to caller");
}
} else {
warn!("Got unsolicited message {} ({:?})", msgid, value);
}
}
});
}
async fn next_message<T>(
stream: &mut BufReader<T>,
) -> Result<Vec<u8>, MessageError>
where
T: AsyncRead + Unpin,
{
let mut buf = vec![];
stream.read_until(b':', &mut buf).await?;
let length: usize =
std::str::from_utf8(&buf[..buf.len() - 1])?.parse()?;
trace!("Message length: {:?}", length);
let mut buf = vec![0; length];
stream.read_exact(&mut buf[..]).await?;
trace!("Received message: {:?}", buf);
Ok(buf)
}
}