diff --git a/Cargo.lock b/Cargo.lock index 097d7ab..32d5947 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -13,6 +13,17 @@ dependencies = [ "futures-core", ] +[[package]] +name = "async-trait" +version = "0.1.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d1d8ab452a3936018a687b20e6f7cf5363d713b732b8884001317b0e48aa3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "autocfg" version = "1.1.0" @@ -363,6 +374,7 @@ dependencies = [ name = "mqttmarionette" version = "0.1.0" dependencies = [ + "async-trait", "inotify", "mozprofile", "mozrunner", diff --git a/Cargo.toml b/Cargo.toml index d6e7262..1a69cc1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] +async-trait = "0.1.60" inotify = "0.10.0" mozprofile = "0.9.0" mozrunner = "0.15.0" diff --git a/src/config.rs b/src/config.rs index 563270f..1a18303 100644 --- a/src/config.rs +++ b/src/config.rs @@ -53,6 +53,8 @@ pub struct MqttConfig { pub username: Option, #[serde(default)] pub password: Option, + #[serde(default = "default_topic_prefix")] + pub topic_prefix: String, } impl Default for MqttConfig { @@ -64,6 +66,7 @@ impl Default for MqttConfig { ca_file: Default::default(), username: Default::default(), password: Default::default(), + topic_prefix: default_topic_prefix(), } } } @@ -81,6 +84,10 @@ const fn default_mqtt_port() -> u16 { 1883 } +fn default_topic_prefix() -> String { + "mqttmarionette".into() +} + pub fn load_config

(path: Option

) -> Result where P: AsRef, diff --git a/src/marionette/message.rs b/src/marionette/message.rs index 4fb3a98..bb174ca 100644 --- a/src/marionette/message.rs +++ b/src/marionette/message.rs @@ -62,9 +62,27 @@ pub struct NewSessionParams { pub strict_file_interactability: bool, } +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +#[allow(dead_code)] +pub struct GetTitleResponse { + pub value: String, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +#[allow(dead_code)] +pub struct NavigateParams { + pub url: String, +} + #[derive(Debug, Serialize)] #[serde(tag = "command", content = "params")] pub enum Command { #[serde(rename = "WebDriver:NewSession")] NewSession(NewSessionParams), + #[serde(rename = "WebDriver:GetTitle")] + GetTitle, + #[serde(rename = "WebDriver:Navigate")] + Navigate(NavigateParams), } diff --git a/src/marionette/mod.rs b/src/marionette/mod.rs index 1fa0716..041d5b4 100644 --- a/src/marionette/mod.rs +++ b/src/marionette/mod.rs @@ -17,7 +17,10 @@ use tokio::sync::oneshot; use tracing::{debug, error, trace, warn}; pub use error::{ConnectionError, MessageError}; -use message::{Command, Hello, NewSessionParams, NewSessionResponse}; +use message::{ + Command, GetTitleResponse, Hello, NavigateParams, NewSessionParams, + NewSessionResponse, +}; #[derive(Debug, Deserialize, Serialize)] struct Message(u8, u32, Option, Option); @@ -48,6 +51,26 @@ impl Marionette { Ok(Self { ts, stream, sender }) } + pub async fn get_title(&mut self) -> Result { + let res: GetTitleResponse = + self.send_message(Command::GetTitle).await?.unwrap(); + debug!("Received message: {:?}", res); + Ok(res.value) + } + + pub async fn navigate(&mut self, url: U) -> Result<(), std::io::Error> + where + U: Into, + { + let res: Option = self + .send_message(Command::Navigate(NavigateParams { + url: url.into(), + })) + .await?; + debug!("Received message: {:?}", res); + Ok(()) + } + pub async fn new_session( &mut self, ) -> Result { diff --git a/src/mqtt.rs b/src/mqtt.rs index 36061c6..bc16459 100644 --- a/src/mqtt.rs +++ b/src/mqtt.rs @@ -1,19 +1,36 @@ use std::time::Duration; -pub use paho_mqtt::Error; +use async_trait::async_trait; +use paho_mqtt::topic_matcher::TopicMatcher; +pub use paho_mqtt::{AsyncClient, Error, Message}; use paho_mqtt::{ - AsyncClient, AsyncReceiver, ConnectOptions, ConnectOptionsBuilder, - CreateOptionsBuilder, Message, ServerResponse, SslOptionsBuilder, + AsyncReceiver, ConnectOptions, ConnectOptionsBuilder, + CreateOptionsBuilder, ServerResponse, SslOptionsBuilder, }; use tokio_stream::StreamExt; use tracing::{info, trace}; use crate::config::Configuration; +#[async_trait] +pub trait MessageHandler { + async fn navigate( + &mut self, + publisher: &MqttPublisher, + msg: &Message, + ); +} + +#[derive(Debug)] +pub enum MessageType { + Navigate, +} + pub struct MqttClient<'a> { config: &'a Configuration, client: AsyncClient, stream: AsyncReceiver>, + topics: TopicMatcher, } impl<'a> MqttClient<'a> { @@ -29,7 +46,13 @@ impl<'a> MqttClient<'a> { CreateOptionsBuilder::new().server_uri(uri).finalize(); let mut client = AsyncClient::new(client_opts)?; let stream = client.get_stream(10); - Ok(Self { config, client, stream }) + let topics = TopicMatcher::new(); + Ok(Self { + config, + client, + stream, + topics, + }) } pub async fn connect(&mut self) -> Result { @@ -38,10 +61,32 @@ impl<'a> MqttClient<'a> { Ok(res) } - pub async fn run(mut self) { + pub async fn subscribe(&mut self) -> Result { + let prefix = &self.config.mqtt.topic_prefix; + let t_nav = format!("{}/+/navigate", prefix); + let res = self.client.subscribe(&t_nav, 0).await?; + self.topics.insert(t_nav, MessageType::Navigate); + Ok(res) + } + + pub async fn run(mut self, mut handler: H) + where + H: MessageHandler, + { + let publisher = MqttPublisher { + config: self.config, + client: &mut self.client, + }; while let Some(msg) = self.stream.next().await { let Some(msg) = msg else {continue}; trace!("Received message: {:?}", msg); + for m in self.topics.matches(msg.topic()) { + match m.1 { + MessageType::Navigate => { + handler.navigate(&publisher, &msg).await; + } + } + } } } @@ -65,3 +110,34 @@ impl<'a> MqttClient<'a> { Ok(conn_opts.finalize()) } } + +pub struct MqttPublisher<'a> { + config: &'a Configuration, + client: &'a mut AsyncClient, +} + +impl<'a> MqttPublisher<'a> { + pub async fn publish_title( + &self, + screen: &str, + title: &str, + ) -> Result<(), Error> { + let topic = + format!("{}/{}/title", self.config.mqtt.topic_prefix, screen); + let msg = Message::new_retained(topic, title, 0); + self.client.publish(msg).await?; + Ok(()) + } + + pub async fn publish_url( + &self, + screen: &str, + url: &str, + ) -> Result<(), Error> { + let topic = + format!("{}/{}/url", self.config.mqtt.topic_prefix, screen); + let msg = Message::new_retained(topic, url, 0); + self.client.publish(msg).await?; + Ok(()) + } +} diff --git a/src/session.rs b/src/session.rs index d37ba5f..a3f8367 100644 --- a/src/session.rs +++ b/src/session.rs @@ -1,12 +1,12 @@ use std::time::Duration; -use tracing::{debug, info, warn}; +use tracing::{debug, error, info, warn}; use crate::browser::{Browser, BrowserError}; use crate::config::Configuration; use crate::marionette::error::ConnectionError; use crate::marionette::Marionette; -use crate::mqtt::MqttClient; +use crate::mqtt::{Message, MqttClient, MqttPublisher}; #[derive(Debug)] pub enum SessionError { @@ -83,15 +83,59 @@ impl Session { }) } - pub async fn run(&self) { + pub async fn run(mut self) { let mut client = MqttClient::new(&self.config).unwrap(); loop { - match client.connect().await { - Ok(_) => break, - Err(e) => warn!("Failed to connect to MQTT server: {}", e), + if let Err(e) = client.connect().await { + warn!("Failed to connect to MQTT server: {}", e); + tokio::time::sleep(Duration::from_secs(1)).await; + continue; } - tokio::time::sleep(Duration::from_secs(1)).await; + if let Err(e) = client.subscribe().await { + warn!("Error subscribing to MQTT topics: {}", e); + tokio::time::sleep(Duration::from_secs(1)).await; + continue; + } + break; + } + let handler = MessageHandler { + marionette: &mut self.marionette, + }; + client.run(handler).await; + } +} + +pub struct MessageHandler<'a> { + marionette: &'a mut Marionette, +} + +#[async_trait::async_trait] +impl<'a> crate::mqtt::MessageHandler for MessageHandler<'a> { + async fn navigate(&mut self, publisher: &MqttPublisher, msg: &Message) { + let url = msg.payload_str(); + let parts: Vec<&str> = msg.topic().split('/').rev().collect(); + let screen = match parts.get(1) { + Some(&"") | None => { + warn!("Invalid navigate request: no screen"); + return; + } + Some(s) => s, + }; + debug!("Handling navigate request: {}", url); + info!("Navigate screen {} to {}", screen, url); + if let Err(e) = self.marionette.navigate(url.to_string()).await { + error!("Failed to navigate: {}", e); + } + if let Err(e) = publisher.publish_url(screen, &url).await { + error!("Failed to publish title: {}", e); + } + match self.marionette.get_title().await { + Ok(t) => { + if let Err(e) = publisher.publish_title(screen, &t).await { + error!("Failed to publish title: {}", e); + } + } + Err(e) => error!("Error getting title: {}", e), } - client.run().await; } }