From ee8ed0c6443beb1ebfea308eccf128fac9a47e71 Mon Sep 17 00:00:00 2001 From: "Dustin C. Hatch" Date: Fri, 30 Dec 2022 13:49:01 -0600 Subject: [PATCH] Add basic MQTT client functionality Naturally, we need a way to configure the MQTT connection parameters (host, port, username, etc.). For that, we'll use a TOML configuration file, which is read at startup and deserialized into a structure owned by the Session. The Session object now has a `run` method, which establishes the MQTT connection and then repeatedly waits for messages from the broker. It will continuously attempt to connect to the broker until it succeeds. This way, if the broker is unavailable when the application starts, it will eventually connect when it becomes available. Once the initial connection is established, the client will automatically reconnect if it gets disconnected later. Since the `run` method loops forever and never returns, we need to use a separate Tokio task to manage it. We keep the task handle so we can cancel the task when the application shuts down. --- .gitignore | 1 + Cargo.lock | 10 +++++ Cargo.toml | 1 + src/config.rs | 99 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/main.rs | 13 ++++++- src/mqtt.rs | 63 ++++++++++++++++++++++++++++++++ src/session.rs | 25 ++++++++++++- 7 files changed, 209 insertions(+), 3 deletions(-) create mode 100644 src/config.rs create mode 100644 src/mqtt.rs diff --git a/.gitignore b/.gitignore index ea8c4bf..fabfb87 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ /target +/config.toml diff --git a/Cargo.lock b/Cargo.lock index 6de7c88..097d7ab 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -371,6 +371,7 @@ dependencies = [ "serde_json", "tokio", "tokio-stream", + "toml", "tracing", "tracing-subscriber", ] @@ -739,6 +740,15 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1333c76748e868a4d9d1017b5ab53171dfd095f70c712fdb4653a406547f598f" +dependencies = [ + "serde", +] + [[package]] name = "tracing" version = "0.1.37" diff --git a/Cargo.toml b/Cargo.toml index cb4ac25..d6e7262 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,5 +12,6 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" tokio = { version = "1.23.0", features = ["io-util", "macros", "net", "rt", "signal", "sync", "time"] } tokio-stream = "0.1.11" +toml = "0.5.10" tracing = "0.1.37" tracing-subscriber = { version = "0.3.16", features = ["env-filter", "fmt"] } diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..563270f --- /dev/null +++ b/src/config.rs @@ -0,0 +1,99 @@ +use std::io::ErrorKind; +use std::path::{Path, PathBuf}; + +use serde::Deserialize; + +#[derive(Debug)] +pub enum ConfigError { + Io(std::io::Error), + Toml(toml::de::Error), +} + +impl From for ConfigError { + fn from(e: std::io::Error) -> Self { + Self::Io(e) + } +} + +impl From for ConfigError { + fn from(e: toml::de::Error) -> Self { + Self::Toml(e) + } +} + +impl std::fmt::Display for ConfigError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Self::Io(e) => write!(f, "Could not read config file: {}", e), + Self::Toml(e) => write!(f, "Could not parse config: {}", e), + } + } +} + +impl std::error::Error for ConfigError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Io(e) => Some(e), + Self::Toml(e) => Some(e), + } + } +} + +#[derive(Debug, Deserialize)] +pub struct MqttConfig { + #[serde(default = "default_mqtt_host")] + pub host: String, + #[serde(default = "default_mqtt_port")] + pub port: u16, + #[serde(default)] + pub tls: bool, + #[serde(default)] + pub ca_file: PathBuf, + #[serde(default)] + pub username: Option, + #[serde(default)] + pub password: Option, +} + +impl Default for MqttConfig { + fn default() -> Self { + Self { + host: default_mqtt_host(), + port: default_mqtt_port(), + tls: Default::default(), + ca_file: Default::default(), + username: Default::default(), + password: Default::default(), + } + } +} + +#[derive(Debug, Default, Deserialize)] +pub struct Configuration { + pub mqtt: MqttConfig, +} + +fn default_mqtt_host() -> String { + "localhost".into() +} + +const fn default_mqtt_port() -> u16 { + 1883 +} + +pub fn load_config

(path: Option

) -> Result +where + P: AsRef, +{ + let path = match path { + Some(p) => PathBuf::from(p.as_ref()), + None => PathBuf::from("config.toml"), + }; + match std::fs::read_to_string(path) { + Ok(s) => Ok(toml::from_str(&s)?), + Err(ref e) if e.kind() == ErrorKind::NotFound => { + Ok(Default::default()) + } + Err(e) => Err(e.into()), + } +} diff --git a/src/main.rs b/src/main.rs index 634b78d..2f2b86d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,7 @@ mod browser; +mod config; mod marionette; +mod mqtt; mod session; use tokio::signal::unix::{self, SignalKind}; @@ -15,13 +17,22 @@ async fn main() { .with_writer(std::io::stderr) .init(); + let config = + config::load_config(std::env::var("MQTTMARIONETTE_CONFIG").ok()) + .unwrap(); + let mut sig_term = unix::signal(SignalKind::terminate()).unwrap(); let mut sig_int = unix::signal(SignalKind::interrupt()).unwrap(); - let session = Session::begin().await.unwrap(); + let task = tokio::spawn(async move { + let session = Session::begin(config).await.unwrap(); + session.run().await; + }); tokio::select! { _ = sig_term.recv() => info!("Received SIGTERM"), _ = sig_int.recv() => info!("Received SIGINT"), }; + + task.abort(); } diff --git a/src/mqtt.rs b/src/mqtt.rs new file mode 100644 index 0000000..a22fa3c --- /dev/null +++ b/src/mqtt.rs @@ -0,0 +1,63 @@ +use std::time::Duration; + +pub use paho_mqtt::Error; +use paho_mqtt::{ + AsyncClient, AsyncReceiver, ConnectOptions, ConnectOptionsBuilder, + CreateOptionsBuilder, Message, SslOptionsBuilder, +}; +use tokio_stream::StreamExt; +use tracing::{info, trace}; + +use crate::config::Configuration; + +pub struct MqttClient { + client: AsyncClient, + stream: AsyncReceiver>, +} + +impl MqttClient { + pub async fn new(config: &Configuration) -> Result { + let uri = format!( + "{}://{}:{}", + if config.mqtt.tls { "ssl" } else { "tcp" }, + config.mqtt.host, + config.mqtt.port + ); + info!("Connecting to MQTT server {}", uri); + let client_opts = + CreateOptionsBuilder::new().server_uri(uri).finalize(); + let mut client = AsyncClient::new(client_opts)?; + let stream = client.get_stream(10); + client.connect(Self::conn_opts(config)?).await?; + info!("Successfully connected to MQTT broker"); + + Ok(Self { client, stream }) + } + + pub async fn run(mut self) { + while let Some(msg) = self.stream.next().await { + let Some(msg) = msg else {continue}; + trace!("Received message: {:?}", msg); + } + } + + fn conn_opts(config: &Configuration) -> Result { + let mut conn_opts = ConnectOptionsBuilder::new(); + conn_opts.automatic_reconnect( + Duration::from_millis(500), + Duration::from_secs(30), + ); + if config.mqtt.tls { + let ssl_opts = SslOptionsBuilder::new() + .trust_store(&config.mqtt.ca_file)? + .finalize(); + conn_opts.ssl_options(ssl_opts); + } + if let [Some(username), Some(password)] = + [&config.mqtt.username, &config.mqtt.password] + { + conn_opts.user_name(username).password(password); + } + Ok(conn_opts.finalize()) + } +} diff --git a/src/session.rs b/src/session.rs index a081a66..2883126 100644 --- a/src/session.rs +++ b/src/session.rs @@ -1,6 +1,10 @@ -use tracing::{debug, info}; +use std::time::Duration; +use tracing::{debug, info, warn}; + +use crate::mqtt::MqttClient; use crate::browser::{Browser, BrowserError}; +use crate::config::Configuration; use crate::marionette::error::ConnectionError; use crate::marionette::Marionette; @@ -53,12 +57,13 @@ impl std::error::Error for SessionError { } pub struct Session { + config: Configuration, browser: Browser, marionette: Marionette, } impl Session { - pub async fn begin() -> Result { + pub async fn begin(config: Configuration) -> Result { debug!("Launching Firefox"); let browser = Browser::launch()?; browser.wait_ready().await?; @@ -72,8 +77,24 @@ impl Session { let ses = marionette.new_session().await?; debug!("Started Marionette session {}", ses.session_id); Ok(Self { + config, browser, marionette, }) } + + pub async fn run(&self) { + let client; + loop { + match MqttClient::new(&self.config).await { + Ok(c) => { + client = c; + break; + } + Err(e) => warn!("Failed to connect to MQTT server: {}", e), + } + tokio::time::sleep(Duration::from_secs(1)).await; + } + client.run().await; + } }