Compare commits

..

No commits in common. "4eba92f4a0131b2472b5079ff36321d1151086bb" and "cef128d1efc55e17d5e94c62a61797de17438ae5" have entirely different histories.

8 changed files with 56 additions and 366 deletions

12
Cargo.lock generated
View File

@ -13,17 +13,6 @@ dependencies = [
"futures-core", "futures-core",
] ]
[[package]]
name = "async-trait"
version = "0.1.60"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "677d1d8ab452a3936018a687b20e6f7cf5363d713b732b8884001317b0e48aa3"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "autocfg" name = "autocfg"
version = "1.1.0" version = "1.1.0"
@ -374,7 +363,6 @@ dependencies = [
name = "mqttmarionette" name = "mqttmarionette"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"async-trait",
"inotify", "inotify",
"mozprofile", "mozprofile",
"mozrunner", "mozrunner",

View File

@ -4,7 +4,6 @@ version = "0.1.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
async-trait = "0.1.60"
inotify = "0.10.0" inotify = "0.10.0"
mozprofile = "0.9.0" mozprofile = "0.9.0"
mozrunner = "0.15.0" mozrunner = "0.15.0"

View File

@ -53,8 +53,6 @@ pub struct MqttConfig {
pub username: Option<String>, pub username: Option<String>,
#[serde(default)] #[serde(default)]
pub password: Option<String>, pub password: Option<String>,
#[serde(default = "default_topic_prefix")]
pub topic_prefix: String,
} }
impl Default for MqttConfig { impl Default for MqttConfig {
@ -66,7 +64,6 @@ impl Default for MqttConfig {
ca_file: Default::default(), ca_file: Default::default(),
username: Default::default(), username: Default::default(),
password: Default::default(), password: Default::default(),
topic_prefix: default_topic_prefix(),
} }
} }
} }
@ -84,10 +81,6 @@ const fn default_mqtt_port() -> u16 {
1883 1883
} }
fn default_topic_prefix() -> String {
"mqttmarionette".into()
}
pub fn load_config<P>(path: Option<P>) -> Result<Configuration, ConfigError> pub fn load_config<P>(path: Option<P>) -> Result<Configuration, ConfigError>
where where
P: AsRef<Path>, P: AsRef<Path>,

View File

@ -1,8 +1,6 @@
use std::num::ParseIntError; use std::num::ParseIntError;
use std::str::Utf8Error; use std::str::Utf8Error;
use serde::Deserialize;
#[derive(Debug)] #[derive(Debug)]
pub enum MessageError { pub enum MessageError {
Io(std::io::Error), Io(std::io::Error),
@ -92,63 +90,3 @@ impl std::error::Error for ConnectionError {
} }
} }
} }
#[derive(Debug, Deserialize)]
pub struct ErrorResponse {
pub error: String,
pub message: String,
pub stacktrace: String,
}
impl std::fmt::Display for ErrorResponse {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}: {}", self.error, self.message)
}
}
impl std::error::Error for ErrorResponse {}
#[derive(Debug)]
pub enum CommandError {
Io(std::io::Error),
Command(ErrorResponse),
Json(serde_json::Error),
}
impl From<std::io::Error> for CommandError {
fn from(e: std::io::Error) -> Self {
Self::Io(e)
}
}
impl From<ErrorResponse> for CommandError {
fn from(e: ErrorResponse) -> Self {
Self::Command(e)
}
}
impl From<serde_json::Error> for CommandError {
fn from(e: serde_json::Error) -> Self {
Self::Json(e)
}
}
impl std::fmt::Display for CommandError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Self::Io(e) => write!(f, "I/O error: {}", e),
Self::Command(e) => write!(f, "Marionette command error: {}", e),
Self::Json(e) => write!(f, "JSON deserialization error: {}", e),
}
}
}
impl std::error::Error for CommandError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::Io(e) => Some(e),
Self::Command(e) => Some(e),
Self::Json(e) => Some(e),
}
}
}

View File

@ -62,36 +62,9 @@ pub struct NewSessionParams {
pub strict_file_interactability: bool, pub strict_file_interactability: bool,
} }
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
#[allow(dead_code)]
pub struct GetTitleResponse {
pub value: String,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
#[allow(dead_code)]
pub struct GetCurrentUrlResponse {
pub value: String,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
#[allow(dead_code)]
pub struct NavigateParams {
pub url: String,
}
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
#[serde(tag = "command", content = "params")] #[serde(tag = "command", content = "params")]
pub enum Command { pub enum Command {
#[serde(rename = "WebDriver:NewSession")] #[serde(rename = "WebDriver:NewSession")]
NewSession(NewSessionParams), NewSession(NewSessionParams),
#[serde(rename = "WebDriver:GetTitle")]
GetTitle,
#[serde(rename = "WebDriver:Navigate")]
Navigate(NavigateParams),
#[serde(rename = "WebDriver:GetCurrentURL")]
GetCurrentUrl,
} }

View File

@ -16,34 +16,21 @@ use tokio::net::{TcpStream, ToSocketAddrs};
use tokio::sync::oneshot; use tokio::sync::oneshot;
use tracing::{debug, error, trace, warn}; use tracing::{debug, error, trace, warn};
pub use error::{CommandError, ConnectionError, ErrorResponse, MessageError}; pub use error::{ConnectionError, MessageError};
use message::{ use message::{Command, Hello, NewSessionParams, NewSessionResponse};
Command, GetCurrentUrlResponse, GetTitleResponse, Hello, NavigateParams,
NewSessionParams, NewSessionResponse,
};
#[derive(Debug, Deserialize, Serialize)] #[derive(Debug, Deserialize, Serialize)]
struct Message(u8, u32, String, Option<serde_json::Value>); struct Message(u8, u32, Option<String>, Option<serde_json::Value>);
#[derive(Debug, Deserialize, Serialize)] type SenderMap = HashMap<u32, oneshot::Sender<Option<serde_json::Value>>>;
struct Response(
u8,
u32,
Option<serde_json::Value>,
Option<serde_json::Value>,
);
type CommandResult = Result<Option<serde_json::Value>, ErrorResponse>; pub struct Marionette {
type SenderMap = HashMap<u32, oneshot::Sender<CommandResult>>;
pub struct MarionetteConnection {
ts: Instant, ts: Instant,
stream: BufWriter<OwnedWriteHalf>, stream: BufWriter<OwnedWriteHalf>,
sender: Arc<Mutex<SenderMap>>, sender: Arc<Mutex<SenderMap>>,
} }
impl MarionetteConnection { impl Marionette {
pub async fn connect<A>(addr: A) -> Result<Self, ConnectionError> pub async fn connect<A>(addr: A) -> Result<Self, ConnectionError>
where where
A: ToSocketAddrs, A: ToSocketAddrs,
@ -61,10 +48,23 @@ impl MarionetteConnection {
Ok(Self { ts, stream, sender }) Ok(Self { ts, stream, sender })
} }
pub async fn new_session(
&mut self,
) -> Result<NewSessionResponse, std::io::Error> {
let res = self
.send_message(Command::NewSession(NewSessionParams {
strict_file_interactability: true,
}))
.await?
.unwrap();
debug!("Received message: {:?}", res);
Ok(res)
}
pub async fn send_message<T>( pub async fn send_message<T>(
&mut self, &mut self,
command: Command, command: Command,
) -> Result<Option<T>, CommandError> ) -> Result<Option<T>, std::io::Error>
where where
T: DeserializeOwned, T: DeserializeOwned,
{ {
@ -74,7 +74,7 @@ impl MarionetteConnection {
value.get("params").cloned(), value.get("params").cloned(),
); );
let msgid = (self.ts.elapsed().as_millis() % u32::MAX as u128) as u32; let msgid = (self.ts.elapsed().as_millis() % u32::MAX as u128) as u32;
let message = Message(0, msgid, command, params); let message = Message(0, msgid, Some(command), params);
let message = serde_json::to_string(&message)?; let message = serde_json::to_string(&message)?;
let message = format!("{}:{}", message.len(), message); let message = format!("{}:{}", message.len(), message);
trace!("Sending message: {}", message); trace!("Sending message: {}", message);
@ -85,10 +85,10 @@ impl MarionetteConnection {
} }
self.stream.write_all(message.as_bytes()).await?; self.stream.write_all(message.as_bytes()).await?;
self.stream.flush().await?; self.stream.flush().await?;
match rx.await.unwrap()? { let Some(r) = rx.await.unwrap() else {
Some(r) => Ok(serde_json::from_value(r)?), return Ok(None)
None => Ok(None), };
} Ok(serde_json::from_value(r)?)
} }
fn start_recv_loop<T>( fn start_recv_loop<T>(
@ -106,7 +106,7 @@ impl MarionetteConnection {
break; break;
} }
}; };
let msg: Response = match serde_json::from_slice(&buf[..]) { let msg: Message = match serde_json::from_slice(&buf[..]) {
Ok(m) => m, Ok(m) => m,
Err(e) => { Err(e) => {
warn!("Error parsing message: {}", e); warn!("Error parsing message: {}", e);
@ -114,23 +114,10 @@ impl MarionetteConnection {
} }
}; };
let msgid = msg.1; let msgid = msg.1;
let error = msg.2;
let value = msg.3; let value = msg.3;
let mut sender = sender.lock().unwrap(); let mut sender = sender.lock().unwrap();
if let Some(s) = sender.remove(&msgid) { if let Some(s) = sender.remove(&msgid) {
let r; if s.send(value).is_err() {
if let Some(e) = error {
match serde_json::from_value(e) {
Ok(e) => r = Err(e),
Err(e) => {
warn!("Error parsing error response: {}", e);
continue;
}
}
} else {
r = Ok(value);
}
if s.send(r).is_err() {
warn!("Failed to send result to caller"); warn!("Failed to send result to caller");
} }
} else { } else {
@ -157,55 +144,3 @@ impl MarionetteConnection {
Ok(buf) Ok(buf)
} }
} }
pub struct Marionette {
conn: MarionetteConnection,
}
impl Marionette {
pub fn new(conn: MarionetteConnection) -> Self {
Self { conn }
}
pub async fn get_title(&mut self) -> Result<String, CommandError> {
let res: GetTitleResponse =
self.conn.send_message(Command::GetTitle).await?.unwrap();
debug!("Received message: {:?}", res);
Ok(res.value)
}
pub async fn get_current_url(&mut self) -> Result<String, CommandError> {
let res: GetCurrentUrlResponse =
self.conn.send_message(Command::GetCurrentUrl).await?.unwrap();
debug!("Received message: {:?}", res);
Ok(res.value)
}
pub async fn navigate<U>(&mut self, url: U) -> Result<(), CommandError>
where
U: Into<String>,
{
let res: Option<serde_json::Value> = self
.conn
.send_message(Command::Navigate(NavigateParams {
url: url.into(),
}))
.await?;
debug!("Received message: {:?}", res);
Ok(())
}
pub async fn new_session(
&mut self,
) -> Result<NewSessionResponse, CommandError> {
let res = self
.conn
.send_message(Command::NewSession(NewSessionParams {
strict_file_interactability: true,
}))
.await?
.unwrap();
debug!("Received message: {:?}", res);
Ok(res)
}
}

View File

@ -1,40 +1,22 @@
use std::time::Duration; use std::time::Duration;
use async_trait::async_trait; pub use paho_mqtt::Error;
use paho_mqtt::topic_matcher::TopicMatcher;
pub use paho_mqtt::{AsyncClient, Error, Message};
use paho_mqtt::{ use paho_mqtt::{
AsyncReceiver, ConnectOptions, ConnectOptionsBuilder, AsyncClient, AsyncReceiver, ConnectOptions, ConnectOptionsBuilder,
CreateOptionsBuilder, ServerResponse, SslOptionsBuilder, CreateOptionsBuilder, Message, SslOptionsBuilder,
}; };
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use tracing::{info, trace}; use tracing::{info, trace};
use crate::config::Configuration; use crate::config::Configuration;
#[async_trait] pub struct MqttClient {
pub trait MessageHandler {
async fn navigate(
&mut self,
publisher: &MqttPublisher,
msg: &Message,
);
}
#[derive(Debug)]
pub enum MessageType {
Navigate,
}
pub struct MqttClient<'a> {
config: &'a Configuration,
client: AsyncClient, client: AsyncClient,
stream: AsyncReceiver<Option<Message>>, stream: AsyncReceiver<Option<Message>>,
topics: TopicMatcher<MessageType>,
} }
impl<'a> MqttClient<'a> { impl MqttClient {
pub fn new(config: &'a Configuration) -> Result<Self, Error> { pub async fn new(config: &Configuration) -> Result<MqttClient, Error> {
let uri = format!( let uri = format!(
"{}://{}:{}", "{}://{}:{}",
if config.mqtt.tls { "ssl" } else { "tcp" }, if config.mqtt.tls { "ssl" } else { "tcp" },
@ -46,98 +28,36 @@ impl<'a> MqttClient<'a> {
CreateOptionsBuilder::new().server_uri(uri).finalize(); CreateOptionsBuilder::new().server_uri(uri).finalize();
let mut client = AsyncClient::new(client_opts)?; let mut client = AsyncClient::new(client_opts)?;
let stream = client.get_stream(10); let stream = client.get_stream(10);
let topics = TopicMatcher::new(); client.connect(Self::conn_opts(config)?).await?;
Ok(Self {
config,
client,
stream,
topics,
})
}
pub async fn connect(&mut self) -> Result<ServerResponse, Error> {
let res = self.client.connect(self.conn_opts()?).await?;
info!("Successfully connected to MQTT broker"); info!("Successfully connected to MQTT broker");
Ok(res)
Ok(Self { client, stream })
} }
pub async fn subscribe(&mut self) -> Result<ServerResponse, Error> { pub async fn run(mut self) {
let prefix = &self.config.mqtt.topic_prefix;
let t_nav = format!("{}/+/navigate", prefix);
let res = self.client.subscribe(&t_nav, 0).await?;
self.topics.insert(t_nav, MessageType::Navigate);
Ok(res)
}
pub async fn run<H>(mut self, mut handler: H)
where
H: MessageHandler,
{
let publisher = MqttPublisher {
config: self.config,
client: &mut self.client,
};
while let Some(msg) = self.stream.next().await { while let Some(msg) = self.stream.next().await {
let Some(msg) = msg else {continue}; let Some(msg) = msg else {continue};
trace!("Received message: {:?}", msg); trace!("Received message: {:?}", msg);
for m in self.topics.matches(msg.topic()) {
match m.1 {
MessageType::Navigate => {
handler.navigate(&publisher, &msg).await;
}
}
}
} }
} }
fn conn_opts(&self) -> Result<ConnectOptions, Error> { fn conn_opts(config: &Configuration) -> Result<ConnectOptions, Error> {
let mut conn_opts = ConnectOptionsBuilder::new(); let mut conn_opts = ConnectOptionsBuilder::new();
conn_opts.automatic_reconnect( conn_opts.automatic_reconnect(
Duration::from_millis(500), Duration::from_millis(500),
Duration::from_secs(30), Duration::from_secs(30),
); );
if self.config.mqtt.tls { if config.mqtt.tls {
let ssl_opts = SslOptionsBuilder::new() let ssl_opts = SslOptionsBuilder::new()
.trust_store(&self.config.mqtt.ca_file)? .trust_store(&config.mqtt.ca_file)?
.finalize(); .finalize();
conn_opts.ssl_options(ssl_opts); conn_opts.ssl_options(ssl_opts);
} }
if let [Some(username), Some(password)] = if let [Some(username), Some(password)] =
[&self.config.mqtt.username, &self.config.mqtt.password] [&config.mqtt.username, &config.mqtt.password]
{ {
conn_opts.user_name(username).password(password); conn_opts.user_name(username).password(password);
} }
Ok(conn_opts.finalize()) Ok(conn_opts.finalize())
} }
} }
pub struct MqttPublisher<'a> {
config: &'a Configuration,
client: &'a mut AsyncClient,
}
impl<'a> MqttPublisher<'a> {
pub async fn publish_title(
&self,
screen: &str,
title: &str,
) -> Result<(), Error> {
let topic =
format!("{}/{}/title", self.config.mqtt.topic_prefix, screen);
let msg = Message::new_retained(topic, title, 0);
self.client.publish(msg).await?;
Ok(())
}
pub async fn publish_url(
&self,
screen: &str,
url: &str,
) -> Result<(), Error> {
let topic =
format!("{}/{}/url", self.config.mqtt.topic_prefix, screen);
let msg = Message::new_retained(topic, url, 0);
self.client.publish(msg).await?;
Ok(())
}
}

View File

@ -1,12 +1,12 @@
use std::time::Duration; use std::time::Duration;
use tracing::{debug, error, info, warn}; use tracing::{debug, info, warn};
use crate::mqtt::MqttClient;
use crate::browser::{Browser, BrowserError}; use crate::browser::{Browser, BrowserError};
use crate::config::Configuration; use crate::config::Configuration;
use crate::marionette::error::{CommandError, ConnectionError}; use crate::marionette::error::ConnectionError;
use crate::marionette::{Marionette, MarionetteConnection}; use crate::marionette::Marionette;
use crate::mqtt::{Message, MqttClient, MqttPublisher};
#[derive(Debug)] #[derive(Debug)]
pub enum SessionError { pub enum SessionError {
@ -14,7 +14,6 @@ pub enum SessionError {
Io(std::io::Error), Io(std::io::Error),
Connection(ConnectionError), Connection(ConnectionError),
InvalidState(String), InvalidState(String),
Command(CommandError),
} }
impl From<BrowserError> for SessionError { impl From<BrowserError> for SessionError {
@ -35,12 +34,6 @@ impl From<ConnectionError> for SessionError {
} }
} }
impl From<CommandError> for SessionError {
fn from(e: CommandError) -> Self {
Self::Command(e)
}
}
impl std::fmt::Display for SessionError { impl std::fmt::Display for SessionError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self { match self {
@ -48,7 +41,6 @@ impl std::fmt::Display for SessionError {
Self::Io(e) => write!(f, "I/O error: {}", e), Self::Io(e) => write!(f, "I/O error: {}", e),
Self::Connection(e) => write!(f, "Connection error: {}", e), Self::Connection(e) => write!(f, "Connection error: {}", e),
Self::InvalidState(e) => write!(f, "Invalid state: {}", e), Self::InvalidState(e) => write!(f, "Invalid state: {}", e),
Self::Command(e) => write!(f, "Marionette command failed: {}", e),
} }
} }
} }
@ -60,7 +52,6 @@ impl std::error::Error for SessionError {
Self::Io(e) => Some(e), Self::Io(e) => Some(e),
Self::Connection(e) => Some(e), Self::Connection(e) => Some(e),
Self::InvalidState(_) => None, Self::InvalidState(_) => None,
Self::Command(e) => Some(e),
} }
} }
} }
@ -81,8 +72,7 @@ impl Session {
return Err(SessionError::InvalidState("No active Marionette port".into())); return Err(SessionError::InvalidState("No active Marionette port".into()));
}; };
debug!("Connecting to Firefox Marionette on port {}", port); debug!("Connecting to Firefox Marionette on port {}", port);
let conn = MarionetteConnection::connect(("127.0.0.1", port)).await?; let mut marionette = Marionette::connect(("127.0.0.1", port)).await?;
let mut marionette = Marionette::new(conn);
info!("Successfully connected to Firefox Marionette"); info!("Successfully connected to Firefox Marionette");
let ses = marionette.new_session().await?; let ses = marionette.new_session().await?;
debug!("Started Marionette session {}", ses.session_id); debug!("Started Marionette session {}", ses.session_id);
@ -93,64 +83,18 @@ impl Session {
}) })
} }
pub async fn run(mut self) { pub async fn run(&self) {
let mut client = MqttClient::new(&self.config).unwrap(); let client;
loop { loop {
if let Err(e) = client.connect().await { match MqttClient::new(&self.config).await {
warn!("Failed to connect to MQTT server: {}", e); Ok(c) => {
tokio::time::sleep(Duration::from_secs(1)).await; client = c;
continue; break;
}
Err(e) => warn!("Failed to connect to MQTT server: {}", e),
} }
if let Err(e) = client.subscribe().await { tokio::time::sleep(Duration::from_secs(1)).await;
warn!("Error subscribing to MQTT topics: {}", e);
tokio::time::sleep(Duration::from_secs(1)).await;
continue;
}
break;
}
let handler = MessageHandler {
marionette: &mut self.marionette,
};
client.run(handler).await;
}
}
pub struct MessageHandler<'a> {
marionette: &'a mut Marionette,
}
#[async_trait::async_trait]
impl<'a> crate::mqtt::MessageHandler for MessageHandler<'a> {
async fn navigate(&mut self, publisher: &MqttPublisher, msg: &Message) {
let url = msg.payload_str();
let parts: Vec<&str> = msg.topic().split('/').rev().collect();
let screen = match parts.get(1) {
Some(&"") | None => {
warn!("Invalid navigate request: no screen");
return;
}
Some(s) => s,
};
debug!("Handling navigate request: {}", url);
info!("Navigate screen {} to {}", screen, url);
if let Err(e) = self.marionette.navigate(url.to_string()).await {
error!("Failed to navigate: {}", e);
}
match self.marionette.get_current_url().await {
Ok(u) => {
if let Err(e) = publisher.publish_url(screen, &u).await {
error!("Failed to publish URL: {}", e);
}
}
Err(e) => error!("Failed to get current browser URL: {}", e),
}
match self.marionette.get_title().await {
Ok(t) => {
if let Err(e) = publisher.publish_title(screen, &t).await {
error!("Failed to publish title: {}", e);
}
}
Err(e) => error!("Error getting title: {}", e),
} }
client.run().await;
} }
} }