Add basic MQTT client functionality

Naturally, we need a way to configure the MQTT connection parameters
(host, port, username, etc.).  For that, we'll use a TOML configuration
file, which is read at startup and deserialized into a structure owned
by the Session.

The Session object now has a `run` method, which establishes the MQTT
connection and then repeatedly waits for messages from the broker.  It
will continuously attempt to connect to the broker until it succeeds.
This way, if the broker is unavailable when the application starts, it
will eventually connect when it becomes available.  Once the initial
connection is established, the client will automatically reconnect if it
gets disconnected later.

Since the `run` method loops forever and never returns, we need to use a
separate Tokio task to manage it.  We keep the task handle so we can
cancel the task when the application shuts down.
dev/ci
Dustin 2022-12-30 13:49:01 -06:00
parent ce2d77a32c
commit ee8ed0c644
7 changed files with 209 additions and 3 deletions

1
.gitignore vendored
View File

@ -1 +1,2 @@
/target
/config.toml

10
Cargo.lock generated
View File

@ -371,6 +371,7 @@ dependencies = [
"serde_json",
"tokio",
"tokio-stream",
"toml",
"tracing",
"tracing-subscriber",
]
@ -739,6 +740,15 @@ dependencies = [
"tokio",
]
[[package]]
name = "toml"
version = "0.5.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1333c76748e868a4d9d1017b5ab53171dfd095f70c712fdb4653a406547f598f"
dependencies = [
"serde",
]
[[package]]
name = "tracing"
version = "0.1.37"

View File

@ -12,5 +12,6 @@ serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1.23.0", features = ["io-util", "macros", "net", "rt", "signal", "sync", "time"] }
tokio-stream = "0.1.11"
toml = "0.5.10"
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.16", features = ["env-filter", "fmt"] }

99
src/config.rs Normal file
View File

@ -0,0 +1,99 @@
use std::io::ErrorKind;
use std::path::{Path, PathBuf};
use serde::Deserialize;
#[derive(Debug)]
pub enum ConfigError {
Io(std::io::Error),
Toml(toml::de::Error),
}
impl From<std::io::Error> for ConfigError {
fn from(e: std::io::Error) -> Self {
Self::Io(e)
}
}
impl From<toml::de::Error> for ConfigError {
fn from(e: toml::de::Error) -> Self {
Self::Toml(e)
}
}
impl std::fmt::Display for ConfigError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Self::Io(e) => write!(f, "Could not read config file: {}", e),
Self::Toml(e) => write!(f, "Could not parse config: {}", e),
}
}
}
impl std::error::Error for ConfigError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::Io(e) => Some(e),
Self::Toml(e) => Some(e),
}
}
}
#[derive(Debug, Deserialize)]
pub struct MqttConfig {
#[serde(default = "default_mqtt_host")]
pub host: String,
#[serde(default = "default_mqtt_port")]
pub port: u16,
#[serde(default)]
pub tls: bool,
#[serde(default)]
pub ca_file: PathBuf,
#[serde(default)]
pub username: Option<String>,
#[serde(default)]
pub password: Option<String>,
}
impl Default for MqttConfig {
fn default() -> Self {
Self {
host: default_mqtt_host(),
port: default_mqtt_port(),
tls: Default::default(),
ca_file: Default::default(),
username: Default::default(),
password: Default::default(),
}
}
}
#[derive(Debug, Default, Deserialize)]
pub struct Configuration {
pub mqtt: MqttConfig,
}
fn default_mqtt_host() -> String {
"localhost".into()
}
const fn default_mqtt_port() -> u16 {
1883
}
pub fn load_config<P>(path: Option<P>) -> Result<Configuration, ConfigError>
where
P: AsRef<Path>,
{
let path = match path {
Some(p) => PathBuf::from(p.as_ref()),
None => PathBuf::from("config.toml"),
};
match std::fs::read_to_string(path) {
Ok(s) => Ok(toml::from_str(&s)?),
Err(ref e) if e.kind() == ErrorKind::NotFound => {
Ok(Default::default())
}
Err(e) => Err(e.into()),
}
}

View File

@ -1,5 +1,7 @@
mod browser;
mod config;
mod marionette;
mod mqtt;
mod session;
use tokio::signal::unix::{self, SignalKind};
@ -15,13 +17,22 @@ async fn main() {
.with_writer(std::io::stderr)
.init();
let config =
config::load_config(std::env::var("MQTTMARIONETTE_CONFIG").ok())
.unwrap();
let mut sig_term = unix::signal(SignalKind::terminate()).unwrap();
let mut sig_int = unix::signal(SignalKind::interrupt()).unwrap();
let session = Session::begin().await.unwrap();
let task = tokio::spawn(async move {
let session = Session::begin(config).await.unwrap();
session.run().await;
});
tokio::select! {
_ = sig_term.recv() => info!("Received SIGTERM"),
_ = sig_int.recv() => info!("Received SIGINT"),
};
task.abort();
}

63
src/mqtt.rs Normal file
View File

@ -0,0 +1,63 @@
use std::time::Duration;
pub use paho_mqtt::Error;
use paho_mqtt::{
AsyncClient, AsyncReceiver, ConnectOptions, ConnectOptionsBuilder,
CreateOptionsBuilder, Message, SslOptionsBuilder,
};
use tokio_stream::StreamExt;
use tracing::{info, trace};
use crate::config::Configuration;
pub struct MqttClient {
client: AsyncClient,
stream: AsyncReceiver<Option<Message>>,
}
impl MqttClient {
pub async fn new(config: &Configuration) -> Result<MqttClient, Error> {
let uri = format!(
"{}://{}:{}",
if config.mqtt.tls { "ssl" } else { "tcp" },
config.mqtt.host,
config.mqtt.port
);
info!("Connecting to MQTT server {}", uri);
let client_opts =
CreateOptionsBuilder::new().server_uri(uri).finalize();
let mut client = AsyncClient::new(client_opts)?;
let stream = client.get_stream(10);
client.connect(Self::conn_opts(config)?).await?;
info!("Successfully connected to MQTT broker");
Ok(Self { client, stream })
}
pub async fn run(mut self) {
while let Some(msg) = self.stream.next().await {
let Some(msg) = msg else {continue};
trace!("Received message: {:?}", msg);
}
}
fn conn_opts(config: &Configuration) -> Result<ConnectOptions, Error> {
let mut conn_opts = ConnectOptionsBuilder::new();
conn_opts.automatic_reconnect(
Duration::from_millis(500),
Duration::from_secs(30),
);
if config.mqtt.tls {
let ssl_opts = SslOptionsBuilder::new()
.trust_store(&config.mqtt.ca_file)?
.finalize();
conn_opts.ssl_options(ssl_opts);
}
if let [Some(username), Some(password)] =
[&config.mqtt.username, &config.mqtt.password]
{
conn_opts.user_name(username).password(password);
}
Ok(conn_opts.finalize())
}
}

View File

@ -1,6 +1,10 @@
use tracing::{debug, info};
use std::time::Duration;
use tracing::{debug, info, warn};
use crate::mqtt::MqttClient;
use crate::browser::{Browser, BrowserError};
use crate::config::Configuration;
use crate::marionette::error::ConnectionError;
use crate::marionette::Marionette;
@ -53,12 +57,13 @@ impl std::error::Error for SessionError {
}
pub struct Session {
config: Configuration,
browser: Browser,
marionette: Marionette,
}
impl Session {
pub async fn begin() -> Result<Self, SessionError> {
pub async fn begin(config: Configuration) -> Result<Self, SessionError> {
debug!("Launching Firefox");
let browser = Browser::launch()?;
browser.wait_ready().await?;
@ -72,8 +77,24 @@ impl Session {
let ses = marionette.new_session().await?;
debug!("Started Marionette session {}", ses.session_id);
Ok(Self {
config,
browser,
marionette,
})
}
pub async fn run(&self) {
let client;
loop {
match MqttClient::new(&self.config).await {
Ok(c) => {
client = c;
break;
}
Err(e) => warn!("Failed to connect to MQTT server: {}", e),
}
tokio::time::sleep(Duration::from_secs(1)).await;
}
client.run().await;
}
}