// mqttmarionette/src/marionette/mod.rs
//
// (File viewer metadata: 312 lines, 9.1 KiB, Rust.)

pub mod error;
pub mod message;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Instant;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use tokio::io::{
AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader,
BufWriter,
};
use tokio::net::tcp::OwnedWriteHalf;
use tokio::net::{TcpStream, ToSocketAddrs};
use tokio::sync::oneshot;
use tracing::{debug, error, trace, warn};
pub use error::{CommandError, ConnectionError, ErrorResponse, MessageError};
use message::{
CloseWindowParams, Command, GetCurrentUrlResponse, GetTitleResponse,
Hello, NavigateParams, NewSessionParams, NewSessionResponse,
NewWindowParams, NewWindowResponse, SwitchToWindowParams, WindowRect,
WindowType,
};
/// A command frame sent to Marionette.
///
/// Serialized as the JSON array `[type, msgid, command, params]`;
/// `send_message` always sends type `0`.
#[derive(Debug, Serialize, Deserialize)]
struct Message(
    u8,
    u32,
    String,
    Option<serde_json::Value>,
);
/// A response frame received from Marionette.
///
/// Deserialized from the JSON array `[type, msgid, error, result]`. The
/// receive loop uses the id, error and result; the leading type field is
/// not checked.
#[derive(Debug, Serialize, Deserialize)]
struct Response(u8, u32, Option<serde_json::Value>, Option<serde_json::Value>);
/// What a waiting caller receives: the raw result value on success, or the
/// server's parsed [`ErrorResponse`].
type CommandResult = Result<Option<serde_json::Value>, ErrorResponse>;

/// One-shot senders for commands awaiting a response, keyed by message id.
type SenderMap = HashMap<u32, oneshot::Sender<CommandResult>>;
pub struct MarionetteConnection {
ts: Instant,
stream: BufWriter<OwnedWriteHalf>,
sender: Arc<Mutex<SenderMap>>,
}
impl MarionetteConnection {
pub async fn connect<A>(addr: A) -> Result<Self, ConnectionError>
where
A: ToSocketAddrs,
{
let conn = TcpStream::connect(addr).await?;
let (read, write) = conn.into_split();
let stream = BufWriter::new(write);
let mut rstream = BufReader::new(read);
let ts = Instant::now();
let sender = Arc::new(Mutex::new(HashMap::new()));
let buf = Self::next_message(&mut rstream).await?;
let hello: Hello = serde_json::from_slice(&buf)?;
debug!("Received hello: {:?}", hello);
Self::start_recv_loop(rstream, sender.clone());
Ok(Self { ts, stream, sender })
}
pub async fn send_message<T>(
&mut self,
command: Command,
) -> Result<Option<T>, CommandError>
where
T: DeserializeOwned,
{
let value = serde_json::to_value(command)?;
let (command, params) = (
value.get("command").unwrap().as_str().unwrap().into(),
value.get("params").cloned(),
);
let msgid = (self.ts.elapsed().as_millis() % u32::MAX as u128) as u32;
let message = Message(0, msgid, command, params);
let message = serde_json::to_string(&message)?;
let message = format!("{}:{}", message.len(), message);
trace!("Sending message: {}", message);
let (tx, rx) = oneshot::channel();
{
let mut sender = self.sender.lock().unwrap();
sender.insert(msgid, tx);
}
self.stream.write_all(message.as_bytes()).await?;
self.stream.flush().await?;
match rx.await.unwrap()? {
Some(r) => Ok(serde_json::from_value(r)?),
None => Ok(None),
}
}
fn start_recv_loop<T>(
mut stream: BufReader<T>,
sender: Arc<Mutex<SenderMap>>,
) where
T: AsyncRead + Send + Unpin + 'static,
{
tokio::spawn(async move {
loop {
let buf = match Self::next_message(&mut stream).await {
Ok(b) => b,
Err(e) => {
error!("Error receiving message: {:?}", e);
break;
}
};
let msg: Response = match serde_json::from_slice(&buf[..]) {
Ok(m) => m,
Err(e) => {
warn!("Error parsing message: {}", e);
continue;
}
};
let msgid = msg.1;
let error = msg.2;
let value = msg.3;
let mut sender = sender.lock().unwrap();
if let Some(s) = sender.remove(&msgid) {
let r;
if let Some(e) = error {
match serde_json::from_value(e) {
Ok(e) => r = Err(e),
Err(e) => {
warn!("Error parsing error response: {}", e);
continue;
}
}
} else {
r = Ok(value);
}
if s.send(r).is_err() {
warn!("Failed to send result to caller");
}
} else {
warn!("Got unsolicited message {} ({:?})", msgid, value);
}
}
});
}
async fn next_message<T>(
stream: &mut BufReader<T>,
) -> Result<Vec<u8>, MessageError>
where
T: AsyncRead + Unpin,
{
let mut buf = vec![];
stream.read_until(b':', &mut buf).await?;
if buf.is_empty() {
return Err(MessageError::Disconnected);
}
let length: usize =
std::str::from_utf8(&buf[..buf.len() - 1])?.parse()?;
trace!("Message length: {:?}", length);
let mut buf = vec![0; length];
stream.read_exact(&mut buf[..]).await?;
trace!("Received message: {:?}", buf);
Ok(buf)
}
}
/// High-level Marionette client exposing one method per supported command.
pub struct Marionette {
    /// Connection every command is sent over.
    conn: MarionetteConnection,
}
impl Marionette {
pub fn new(conn: MarionetteConnection) -> Self {
Self { conn }
}
pub async fn close_window(
&mut self,
handle: impl Into<String>,
) -> Result<(), CommandError> {
let res: Vec<String> = self
.conn
.send_message(Command::CloseWindow(CloseWindowParams {
handle: handle.into(),
}))
.await?
.unwrap();
debug!("Received message: {:?}", res);
Ok(())
}
pub async fn fullscreen(&mut self) -> Result<(), CommandError> {
let res: serde_json::Value = self
.conn
.send_message(Command::FullscreenWindow)
.await?
.unwrap();
debug!("Received message: {:?}", res);
Ok(())
}
pub async fn get_title(&mut self) -> Result<String, CommandError> {
let res: GetTitleResponse =
self.conn.send_message(Command::GetTitle).await?.unwrap();
debug!("Received message: {:?}", res);
Ok(res.value)
}
pub async fn get_current_url(&mut self) -> Result<String, CommandError> {
let res: GetCurrentUrlResponse = self
.conn
.send_message(Command::GetCurrentUrl)
.await?
.unwrap();
debug!("Received message: {:?}", res);
Ok(res.value)
}
pub async fn get_window_handles(
&mut self,
) -> Result<Vec<String>, CommandError> {
let res = self
.conn
.send_message(Command::GetWindowHandles)
.await?
.unwrap();
debug!("Received message: {:?}", res);
Ok(res)
}
pub async fn navigate<U>(&mut self, url: U) -> Result<(), CommandError>
where
U: Into<String>,
{
let res: Option<serde_json::Value> = self
.conn
.send_message(Command::Navigate(NavigateParams {
url: url.into(),
}))
.await?;
debug!("Received message: {:?}", res);
Ok(())
}
pub async fn new_session(
&mut self,
) -> Result<NewSessionResponse, CommandError> {
let res = self
.conn
.send_message(Command::NewSession(NewSessionParams {
strict_file_interactability: true,
}))
.await?
.unwrap();
debug!("Received message: {:?}", res);
Ok(res)
}
pub async fn new_window(
&mut self,
window_type: WindowType,
focus: bool,
) -> Result<String, CommandError> {
let res: NewWindowResponse = self
.conn
.send_message(Command::NewWindow(NewWindowParams {
window_type,
focus,
}))
.await?
.unwrap();
debug!("Received message: {:?}", res);
Ok(res.handle)
}
pub async fn set_window_rect(
&mut self,
x: Option<i32>,
y: Option<i32>,
height: Option<u32>,
width: Option<u32>,
) -> Result<WindowRect, CommandError> {
let res: WindowRect = self
.conn
.send_message(Command::SetWindowRect(WindowRect {
x,
y,
height,
width,
}))
.await?
.unwrap();
debug!("Received message: {:?}", res);
Ok(res)
}
pub async fn switch_to_window(
&mut self,
handle: String,
focus: bool,
) -> Result<(), CommandError> {
let res: serde_json::Value = self
.conn
.send_message(Command::SwitchToWindow(SwitchToWindowParams {
handle,
focus,
}))
.await?
.unwrap();
debug!("Received message: {:?}", res);
Ok(())
}
}