170 lines
5.3 KiB
Rust
170 lines
5.3 KiB
Rust
pub mod error;
|
|
pub mod message;
|
|
|
|
use std::collections::HashMap;
|
|
use std::sync::{Arc, Mutex};
|
|
use std::time::Instant;
|
|
|
|
use serde::de::DeserializeOwned;
|
|
use serde::{Deserialize, Serialize};
|
|
use tokio::io::{
|
|
AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader,
|
|
BufWriter,
|
|
};
|
|
use tokio::net::tcp::OwnedWriteHalf;
|
|
use tokio::net::{TcpStream, ToSocketAddrs};
|
|
use tokio::sync::oneshot;
|
|
use tracing::{debug, error, trace, warn};
|
|
|
|
pub use error::{ConnectionError, MessageError};
|
|
use message::{
|
|
Command, GetTitleResponse, Hello, NavigateParams, NewSessionParams,
|
|
NewSessionResponse,
|
|
};
|
|
|
|
#[derive(Debug, Deserialize, Serialize)]
|
|
struct Message(u8, u32, Option<String>, Option<serde_json::Value>);
|
|
|
|
/// Reply channels of in-flight commands, keyed by message id.
///
/// Each entry is the sending half of a oneshot channel; the value sent
/// through it is the optional JSON payload of the matching reply.
type SenderMap = HashMap<u32, oneshot::Sender<Option<serde_json::Value>>>;
|
|
|
|
pub struct Marionette {
|
|
ts: Instant,
|
|
stream: BufWriter<OwnedWriteHalf>,
|
|
sender: Arc<Mutex<SenderMap>>,
|
|
}
|
|
|
|
impl Marionette {
|
|
pub async fn connect<A>(addr: A) -> Result<Self, ConnectionError>
|
|
where
|
|
A: ToSocketAddrs,
|
|
{
|
|
let conn = TcpStream::connect(addr).await?;
|
|
let (read, write) = conn.into_split();
|
|
let stream = BufWriter::new(write);
|
|
let mut rstream = BufReader::new(read);
|
|
let ts = Instant::now();
|
|
let sender = Arc::new(Mutex::new(HashMap::new()));
|
|
let buf = Self::next_message(&mut rstream).await?;
|
|
let hello: Hello = serde_json::from_slice(&buf)?;
|
|
debug!("Received hello: {:?}", hello);
|
|
Self::start_recv_loop(rstream, sender.clone());
|
|
Ok(Self { ts, stream, sender })
|
|
}
|
|
|
|
pub async fn get_title(&mut self) -> Result<String, std::io::Error> {
|
|
let res: GetTitleResponse =
|
|
self.send_message(Command::GetTitle).await?.unwrap();
|
|
debug!("Received message: {:?}", res);
|
|
Ok(res.value)
|
|
}
|
|
|
|
pub async fn navigate<U>(&mut self, url: U) -> Result<(), std::io::Error>
|
|
where
|
|
U: Into<String>,
|
|
{
|
|
let res: Option<serde_json::Value> = self
|
|
.send_message(Command::Navigate(NavigateParams {
|
|
url: url.into(),
|
|
}))
|
|
.await?;
|
|
debug!("Received message: {:?}", res);
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn new_session(
|
|
&mut self,
|
|
) -> Result<NewSessionResponse, std::io::Error> {
|
|
let res = self
|
|
.send_message(Command::NewSession(NewSessionParams {
|
|
strict_file_interactability: true,
|
|
}))
|
|
.await?
|
|
.unwrap();
|
|
debug!("Received message: {:?}", res);
|
|
Ok(res)
|
|
}
|
|
|
|
pub async fn send_message<T>(
|
|
&mut self,
|
|
command: Command,
|
|
) -> Result<Option<T>, std::io::Error>
|
|
where
|
|
T: DeserializeOwned,
|
|
{
|
|
let value = serde_json::to_value(command)?;
|
|
let (command, params) = (
|
|
value.get("command").unwrap().as_str().unwrap().into(),
|
|
value.get("params").cloned(),
|
|
);
|
|
let msgid = (self.ts.elapsed().as_millis() % u32::MAX as u128) as u32;
|
|
let message = Message(0, msgid, Some(command), params);
|
|
let message = serde_json::to_string(&message)?;
|
|
let message = format!("{}:{}", message.len(), message);
|
|
trace!("Sending message: {}", message);
|
|
let (tx, rx) = oneshot::channel();
|
|
{
|
|
let mut sender = self.sender.lock().unwrap();
|
|
sender.insert(msgid, tx);
|
|
}
|
|
self.stream.write_all(message.as_bytes()).await?;
|
|
self.stream.flush().await?;
|
|
let Some(r) = rx.await.unwrap() else {
|
|
return Ok(None)
|
|
};
|
|
Ok(serde_json::from_value(r)?)
|
|
}
|
|
|
|
fn start_recv_loop<T>(
|
|
mut stream: BufReader<T>,
|
|
sender: Arc<Mutex<SenderMap>>,
|
|
) where
|
|
T: AsyncRead + Send + Unpin + 'static,
|
|
{
|
|
tokio::spawn(async move {
|
|
loop {
|
|
let buf = match Self::next_message(&mut stream).await {
|
|
Ok(b) => b,
|
|
Err(e) => {
|
|
error!("Error receiving message: {:?}", e);
|
|
break;
|
|
}
|
|
};
|
|
let msg: Message = match serde_json::from_slice(&buf[..]) {
|
|
Ok(m) => m,
|
|
Err(e) => {
|
|
warn!("Error parsing message: {}", e);
|
|
continue;
|
|
}
|
|
};
|
|
let msgid = msg.1;
|
|
let value = msg.3;
|
|
let mut sender = sender.lock().unwrap();
|
|
if let Some(s) = sender.remove(&msgid) {
|
|
if s.send(value).is_err() {
|
|
warn!("Failed to send result to caller");
|
|
}
|
|
} else {
|
|
warn!("Got unsolicited message {} ({:?})", msgid, value);
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
async fn next_message<T>(
|
|
stream: &mut BufReader<T>,
|
|
) -> Result<Vec<u8>, MessageError>
|
|
where
|
|
T: AsyncRead + Unpin,
|
|
{
|
|
let mut buf = vec![];
|
|
stream.read_until(b':', &mut buf).await?;
|
|
let length: usize =
|
|
std::str::from_utf8(&buf[..buf.len() - 1])?.parse()?;
|
|
trace!("Message length: {:?}", length);
|
|
let mut buf = vec![0; length];
|
|
stream.read_exact(&mut buf[..]).await?;
|
|
trace!("Received message: {:?}", buf);
|
|
Ok(buf)
|
|
}
|
|
}
|