Compare commits

...

5 Commits

Author SHA1 Message Date
Dustin 4eba92f4a0 Publish final URL after navigation
If the URL specified in a navigation command results in a redirect, we
want to publish the final destination, rather than the provided
location.  Thus, after navigation completes, we get the browser's
current URL and publish that instead.
2022-12-31 10:43:01 -06:00
Dustin 8431a33f20 marionette: Separate commands from connection
The `MarionetteConnection` structure provides the low-level interface to
communicate with the Marionette server via the TCP socket.  In contrast,
the `Marionette` structure provides the high-level interface to execute
Marionette commands.  Separating these will make the code a bit cleaner,
in my opinion.
2022-12-31 10:28:03 -06:00
Dustin 3b1be2d01c marionette: Handle error responses
Marionette commands can return error responses, e.g. if a command fails
or is invalid.  We want to propagate these back to the caller whenever
possible.  The error object is specified in the Marionette protocol
documentation, so we can model it appropriately.
2022-12-30 22:04:44 -06:00
Dustin 4820d0f6cd Begin MQTT control implementation
The pieces are starting to come together.  To control the browser via
MQTT messages, the `MqttClient` dispatches messages via a
`MessageHandler`, which parses them and makes the appropriate Marionette
requests.  The `MessageHandler` trait defines callback methods for each
MQTT control operation, which currently is just `navigate`.  The
operation type is determined by the MQTT topic on which the message was
received.

Several new types are necessary to make this work.  The `MessageHandler`
trait and implementation are of course the core, reacting to incoming
MQTT messages.  In order for the handler to be able to *send* MQTT
messages, though, it needs a reference to the Paho MQTT client.  The
`MqttPublisher` provides a convenient wrapper around the client, with
specific methods for each type of message to send.  Finally, there's the
`MessageType` enumeration, which works in conjunction with the
`TopicMatcher` to match topic names to message types using topic filter
patterns.
2022-12-30 19:06:27 -06:00
Dustin 41c87d87af mqtt: Add MqttClient::connect method
Separating the `connect` call out of the `MqttClient::new` function
makes is such that we do not have to create a new object for each
iteration of the initial connection loop.  Instead, we just create one
object and repeatedly call its `connect` method until it succeeds
2022-12-30 14:39:10 -06:00
8 changed files with 369 additions and 59 deletions

12
Cargo.lock generated
View File

@ -13,6 +13,17 @@ dependencies = [
"futures-core",
]
[[package]]
name = "async-trait"
version = "0.1.60"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "677d1d8ab452a3936018a687b20e6f7cf5363d713b732b8884001317b0e48aa3"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "autocfg"
version = "1.1.0"
@ -363,6 +374,7 @@ dependencies = [
name = "mqttmarionette"
version = "0.1.0"
dependencies = [
"async-trait",
"inotify",
"mozprofile",
"mozrunner",

View File

@ -4,6 +4,7 @@ version = "0.1.0"
edition = "2021"
[dependencies]
async-trait = "0.1.60"
inotify = "0.10.0"
mozprofile = "0.9.0"
mozrunner = "0.15.0"

View File

@ -53,6 +53,8 @@ pub struct MqttConfig {
pub username: Option<String>,
#[serde(default)]
pub password: Option<String>,
#[serde(default = "default_topic_prefix")]
pub topic_prefix: String,
}
impl Default for MqttConfig {
@ -64,6 +66,7 @@ impl Default for MqttConfig {
ca_file: Default::default(),
username: Default::default(),
password: Default::default(),
topic_prefix: default_topic_prefix(),
}
}
}
@ -81,6 +84,10 @@ const fn default_mqtt_port() -> u16 {
1883
}
fn default_topic_prefix() -> String {
"mqttmarionette".into()
}
pub fn load_config<P>(path: Option<P>) -> Result<Configuration, ConfigError>
where
P: AsRef<Path>,

View File

@ -1,6 +1,8 @@
use std::num::ParseIntError;
use std::str::Utf8Error;
use serde::Deserialize;
#[derive(Debug)]
pub enum MessageError {
Io(std::io::Error),
@ -90,3 +92,63 @@ impl std::error::Error for ConnectionError {
}
}
}
#[derive(Debug, Deserialize)]
pub struct ErrorResponse {
pub error: String,
pub message: String,
pub stacktrace: String,
}
impl std::fmt::Display for ErrorResponse {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}: {}", self.error, self.message)
}
}
impl std::error::Error for ErrorResponse {}
#[derive(Debug)]
pub enum CommandError {
Io(std::io::Error),
Command(ErrorResponse),
Json(serde_json::Error),
}
impl From<std::io::Error> for CommandError {
fn from(e: std::io::Error) -> Self {
Self::Io(e)
}
}
impl From<ErrorResponse> for CommandError {
fn from(e: ErrorResponse) -> Self {
Self::Command(e)
}
}
impl From<serde_json::Error> for CommandError {
fn from(e: serde_json::Error) -> Self {
Self::Json(e)
}
}
impl std::fmt::Display for CommandError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Self::Io(e) => write!(f, "I/O error: {}", e),
Self::Command(e) => write!(f, "Marionette command error: {}", e),
Self::Json(e) => write!(f, "JSON deserialization error: {}", e),
}
}
}
impl std::error::Error for CommandError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::Io(e) => Some(e),
Self::Command(e) => Some(e),
Self::Json(e) => Some(e),
}
}
}

View File

@ -62,9 +62,36 @@ pub struct NewSessionParams {
pub strict_file_interactability: bool,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
#[allow(dead_code)]
pub struct GetTitleResponse {
pub value: String,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
#[allow(dead_code)]
pub struct GetCurrentUrlResponse {
pub value: String,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
#[allow(dead_code)]
pub struct NavigateParams {
pub url: String,
}
#[derive(Debug, Serialize)]
#[serde(tag = "command", content = "params")]
pub enum Command {
#[serde(rename = "WebDriver:NewSession")]
NewSession(NewSessionParams),
#[serde(rename = "WebDriver:GetTitle")]
GetTitle,
#[serde(rename = "WebDriver:Navigate")]
Navigate(NavigateParams),
#[serde(rename = "WebDriver:GetCurrentURL")]
GetCurrentUrl,
}

View File

@ -16,21 +16,34 @@ use tokio::net::{TcpStream, ToSocketAddrs};
use tokio::sync::oneshot;
use tracing::{debug, error, trace, warn};
pub use error::{ConnectionError, MessageError};
use message::{Command, Hello, NewSessionParams, NewSessionResponse};
pub use error::{CommandError, ConnectionError, ErrorResponse, MessageError};
use message::{
Command, GetCurrentUrlResponse, GetTitleResponse, Hello, NavigateParams,
NewSessionParams, NewSessionResponse,
};
#[derive(Debug, Deserialize, Serialize)]
struct Message(u8, u32, Option<String>, Option<serde_json::Value>);
struct Message(u8, u32, String, Option<serde_json::Value>);
type SenderMap = HashMap<u32, oneshot::Sender<Option<serde_json::Value>>>;
#[derive(Debug, Deserialize, Serialize)]
struct Response(
u8,
u32,
Option<serde_json::Value>,
Option<serde_json::Value>,
);
pub struct Marionette {
type CommandResult = Result<Option<serde_json::Value>, ErrorResponse>;
type SenderMap = HashMap<u32, oneshot::Sender<CommandResult>>;
pub struct MarionetteConnection {
ts: Instant,
stream: BufWriter<OwnedWriteHalf>,
sender: Arc<Mutex<SenderMap>>,
}
impl Marionette {
impl MarionetteConnection {
pub async fn connect<A>(addr: A) -> Result<Self, ConnectionError>
where
A: ToSocketAddrs,
@ -48,23 +61,10 @@ impl Marionette {
Ok(Self { ts, stream, sender })
}
pub async fn new_session(
&mut self,
) -> Result<NewSessionResponse, std::io::Error> {
let res = self
.send_message(Command::NewSession(NewSessionParams {
strict_file_interactability: true,
}))
.await?
.unwrap();
debug!("Received message: {:?}", res);
Ok(res)
}
pub async fn send_message<T>(
&mut self,
command: Command,
) -> Result<Option<T>, std::io::Error>
) -> Result<Option<T>, CommandError>
where
T: DeserializeOwned,
{
@ -74,7 +74,7 @@ impl Marionette {
value.get("params").cloned(),
);
let msgid = (self.ts.elapsed().as_millis() % u32::MAX as u128) as u32;
let message = Message(0, msgid, Some(command), params);
let message = Message(0, msgid, command, params);
let message = serde_json::to_string(&message)?;
let message = format!("{}:{}", message.len(), message);
trace!("Sending message: {}", message);
@ -85,10 +85,10 @@ impl Marionette {
}
self.stream.write_all(message.as_bytes()).await?;
self.stream.flush().await?;
let Some(r) = rx.await.unwrap() else {
return Ok(None)
};
Ok(serde_json::from_value(r)?)
match rx.await.unwrap()? {
Some(r) => Ok(serde_json::from_value(r)?),
None => Ok(None),
}
}
fn start_recv_loop<T>(
@ -106,7 +106,7 @@ impl Marionette {
break;
}
};
let msg: Message = match serde_json::from_slice(&buf[..]) {
let msg: Response = match serde_json::from_slice(&buf[..]) {
Ok(m) => m,
Err(e) => {
warn!("Error parsing message: {}", e);
@ -114,10 +114,23 @@ impl Marionette {
}
};
let msgid = msg.1;
let error = msg.2;
let value = msg.3;
let mut sender = sender.lock().unwrap();
if let Some(s) = sender.remove(&msgid) {
if s.send(value).is_err() {
let r;
if let Some(e) = error {
match serde_json::from_value(e) {
Ok(e) => r = Err(e),
Err(e) => {
warn!("Error parsing error response: {}", e);
continue;
}
}
} else {
r = Ok(value);
}
if s.send(r).is_err() {
warn!("Failed to send result to caller");
}
} else {
@ -144,3 +157,55 @@ impl Marionette {
Ok(buf)
}
}
pub struct Marionette {
conn: MarionetteConnection,
}
impl Marionette {
pub fn new(conn: MarionetteConnection) -> Self {
Self { conn }
}
pub async fn get_title(&mut self) -> Result<String, CommandError> {
let res: GetTitleResponse =
self.conn.send_message(Command::GetTitle).await?.unwrap();
debug!("Received message: {:?}", res);
Ok(res.value)
}
pub async fn get_current_url(&mut self) -> Result<String, CommandError> {
let res: GetCurrentUrlResponse =
self.conn.send_message(Command::GetCurrentUrl).await?.unwrap();
debug!("Received message: {:?}", res);
Ok(res.value)
}
pub async fn navigate<U>(&mut self, url: U) -> Result<(), CommandError>
where
U: Into<String>,
{
let res: Option<serde_json::Value> = self
.conn
.send_message(Command::Navigate(NavigateParams {
url: url.into(),
}))
.await?;
debug!("Received message: {:?}", res);
Ok(())
}
pub async fn new_session(
&mut self,
) -> Result<NewSessionResponse, CommandError> {
let res = self
.conn
.send_message(Command::NewSession(NewSessionParams {
strict_file_interactability: true,
}))
.await?
.unwrap();
debug!("Received message: {:?}", res);
Ok(res)
}
}

View File

@ -1,22 +1,40 @@
use std::time::Duration;
pub use paho_mqtt::Error;
use async_trait::async_trait;
use paho_mqtt::topic_matcher::TopicMatcher;
pub use paho_mqtt::{AsyncClient, Error, Message};
use paho_mqtt::{
AsyncClient, AsyncReceiver, ConnectOptions, ConnectOptionsBuilder,
CreateOptionsBuilder, Message, SslOptionsBuilder,
AsyncReceiver, ConnectOptions, ConnectOptionsBuilder,
CreateOptionsBuilder, ServerResponse, SslOptionsBuilder,
};
use tokio_stream::StreamExt;
use tracing::{info, trace};
use crate::config::Configuration;
pub struct MqttClient {
client: AsyncClient,
stream: AsyncReceiver<Option<Message>>,
#[async_trait]
pub trait MessageHandler {
async fn navigate(
&mut self,
publisher: &MqttPublisher,
msg: &Message,
);
}
impl MqttClient {
pub async fn new(config: &Configuration) -> Result<MqttClient, Error> {
#[derive(Debug)]
pub enum MessageType {
Navigate,
}
pub struct MqttClient<'a> {
config: &'a Configuration,
client: AsyncClient,
stream: AsyncReceiver<Option<Message>>,
topics: TopicMatcher<MessageType>,
}
impl<'a> MqttClient<'a> {
pub fn new(config: &'a Configuration) -> Result<Self, Error> {
let uri = format!(
"{}://{}:{}",
if config.mqtt.tls { "ssl" } else { "tcp" },
@ -28,36 +46,98 @@ impl MqttClient {
CreateOptionsBuilder::new().server_uri(uri).finalize();
let mut client = AsyncClient::new(client_opts)?;
let stream = client.get_stream(10);
client.connect(Self::conn_opts(config)?).await?;
info!("Successfully connected to MQTT broker");
Ok(Self { client, stream })
let topics = TopicMatcher::new();
Ok(Self {
config,
client,
stream,
topics,
})
}
pub async fn run(mut self) {
pub async fn connect(&mut self) -> Result<ServerResponse, Error> {
let res = self.client.connect(self.conn_opts()?).await?;
info!("Successfully connected to MQTT broker");
Ok(res)
}
pub async fn subscribe(&mut self) -> Result<ServerResponse, Error> {
let prefix = &self.config.mqtt.topic_prefix;
let t_nav = format!("{}/+/navigate", prefix);
let res = self.client.subscribe(&t_nav, 0).await?;
self.topics.insert(t_nav, MessageType::Navigate);
Ok(res)
}
pub async fn run<H>(mut self, mut handler: H)
where
H: MessageHandler,
{
let publisher = MqttPublisher {
config: self.config,
client: &mut self.client,
};
while let Some(msg) = self.stream.next().await {
let Some(msg) = msg else {continue};
trace!("Received message: {:?}", msg);
for m in self.topics.matches(msg.topic()) {
match m.1 {
MessageType::Navigate => {
handler.navigate(&publisher, &msg).await;
}
}
}
}
}
fn conn_opts(config: &Configuration) -> Result<ConnectOptions, Error> {
fn conn_opts(&self) -> Result<ConnectOptions, Error> {
let mut conn_opts = ConnectOptionsBuilder::new();
conn_opts.automatic_reconnect(
Duration::from_millis(500),
Duration::from_secs(30),
);
if config.mqtt.tls {
if self.config.mqtt.tls {
let ssl_opts = SslOptionsBuilder::new()
.trust_store(&config.mqtt.ca_file)?
.trust_store(&self.config.mqtt.ca_file)?
.finalize();
conn_opts.ssl_options(ssl_opts);
}
if let [Some(username), Some(password)] =
[&config.mqtt.username, &config.mqtt.password]
[&self.config.mqtt.username, &self.config.mqtt.password]
{
conn_opts.user_name(username).password(password);
}
Ok(conn_opts.finalize())
}
}
pub struct MqttPublisher<'a> {
config: &'a Configuration,
client: &'a mut AsyncClient,
}
impl<'a> MqttPublisher<'a> {
pub async fn publish_title(
&self,
screen: &str,
title: &str,
) -> Result<(), Error> {
let topic =
format!("{}/{}/title", self.config.mqtt.topic_prefix, screen);
let msg = Message::new_retained(topic, title, 0);
self.client.publish(msg).await?;
Ok(())
}
pub async fn publish_url(
&self,
screen: &str,
url: &str,
) -> Result<(), Error> {
let topic =
format!("{}/{}/url", self.config.mqtt.topic_prefix, screen);
let msg = Message::new_retained(topic, url, 0);
self.client.publish(msg).await?;
Ok(())
}
}

View File

@ -1,12 +1,12 @@
use std::time::Duration;
use tracing::{debug, info, warn};
use tracing::{debug, error, info, warn};
use crate::mqtt::MqttClient;
use crate::browser::{Browser, BrowserError};
use crate::config::Configuration;
use crate::marionette::error::ConnectionError;
use crate::marionette::Marionette;
use crate::marionette::error::{CommandError, ConnectionError};
use crate::marionette::{Marionette, MarionetteConnection};
use crate::mqtt::{Message, MqttClient, MqttPublisher};
#[derive(Debug)]
pub enum SessionError {
@ -14,6 +14,7 @@ pub enum SessionError {
Io(std::io::Error),
Connection(ConnectionError),
InvalidState(String),
Command(CommandError),
}
impl From<BrowserError> for SessionError {
@ -34,6 +35,12 @@ impl From<ConnectionError> for SessionError {
}
}
impl From<CommandError> for SessionError {
fn from(e: CommandError) -> Self {
Self::Command(e)
}
}
impl std::fmt::Display for SessionError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
@ -41,6 +48,7 @@ impl std::fmt::Display for SessionError {
Self::Io(e) => write!(f, "I/O error: {}", e),
Self::Connection(e) => write!(f, "Connection error: {}", e),
Self::InvalidState(e) => write!(f, "Invalid state: {}", e),
Self::Command(e) => write!(f, "Marionette command failed: {}", e),
}
}
}
@ -52,6 +60,7 @@ impl std::error::Error for SessionError {
Self::Io(e) => Some(e),
Self::Connection(e) => Some(e),
Self::InvalidState(_) => None,
Self::Command(e) => Some(e),
}
}
}
@ -72,7 +81,8 @@ impl Session {
return Err(SessionError::InvalidState("No active Marionette port".into()));
};
debug!("Connecting to Firefox Marionette on port {}", port);
let mut marionette = Marionette::connect(("127.0.0.1", port)).await?;
let conn = MarionetteConnection::connect(("127.0.0.1", port)).await?;
let mut marionette = Marionette::new(conn);
info!("Successfully connected to Firefox Marionette");
let ses = marionette.new_session().await?;
debug!("Started Marionette session {}", ses.session_id);
@ -83,18 +93,64 @@ impl Session {
})
}
pub async fn run(&self) {
let client;
pub async fn run(mut self) {
let mut client = MqttClient::new(&self.config).unwrap();
loop {
match MqttClient::new(&self.config).await {
Ok(c) => {
client = c;
if let Err(e) = client.connect().await {
warn!("Failed to connect to MQTT server: {}", e);
tokio::time::sleep(Duration::from_secs(1)).await;
continue;
}
if let Err(e) = client.subscribe().await {
warn!("Error subscribing to MQTT topics: {}", e);
tokio::time::sleep(Duration::from_secs(1)).await;
continue;
}
break;
}
Err(e) => warn!("Failed to connect to MQTT server: {}", e),
}
tokio::time::sleep(Duration::from_secs(1)).await;
}
client.run().await;
let handler = MessageHandler {
marionette: &mut self.marionette,
};
client.run(handler).await;
}
}
pub struct MessageHandler<'a> {
marionette: &'a mut Marionette,
}
#[async_trait::async_trait]
impl<'a> crate::mqtt::MessageHandler for MessageHandler<'a> {
async fn navigate(&mut self, publisher: &MqttPublisher, msg: &Message) {
let url = msg.payload_str();
let parts: Vec<&str> = msg.topic().split('/').rev().collect();
let screen = match parts.get(1) {
Some(&"") | None => {
warn!("Invalid navigate request: no screen");
return;
}
Some(s) => s,
};
debug!("Handling navigate request: {}", url);
info!("Navigate screen {} to {}", screen, url);
if let Err(e) = self.marionette.navigate(url.to_string()).await {
error!("Failed to navigate: {}", e);
}
match self.marionette.get_current_url().await {
Ok(u) => {
if let Err(e) = publisher.publish_url(screen, &u).await {
error!("Failed to publish URL: {}", e);
}
}
Err(e) => error!("Failed to get current browser URL: {}", e),
}
match self.marionette.get_title().await {
Ok(t) => {
if let Err(e) = publisher.publish_title(screen, &t).await {
error!("Failed to publish title: {}", e);
}
}
Err(e) => error!("Error getting title: {}", e),
}
}
}