Send offline message on shutdown

The MQTT broker does *not* send the client's last will and testament
message when the client disconnects gracefully.  I thought it did
because my other Rust/MQTT project, MQTTDPMS, seems to behave that way.
It turns out, though, that in that project, the client never actually
disconnects gracefully.  It has no signal handlers, so when it receives
SIGINT or SIGTERM, it just exits immediately.  This leaves the OS to
forcefully close the TCP connection, so the broker sends the will
message.

Since the Browser HUD process *does* have signal handlers, when it
receives a signal, it shuts down gracefully.  As Rust drops objects
during shut down, the MQTT client eventually disconnects cleanly, so the
broker does not send the will message.  In order to notify Home
Assistant that the device is now unavailable, we have to explicitly send
the offline message before disconnecting the MQTT client.

I've added a `Notify` object that lives for the entire life of the
process and is passed in to the session.  When a signal is received,
this object wakes up the asynchronous tasks that perform the
pre-shutdown operations.  One such task is spawned by the
`MqttClient::run` method; it sends the offline message when notified,
then disconnects the MQTT client.  In order to share the MQTT client
object between this new task and the message receive loop, it has to be
wrapped in an `Arc` and a `Mutex`.
dev/ci
Dustin 2023-01-07 22:00:59 -06:00
parent d4f2c73eca
commit 94a47d863c
3 changed files with 53 additions and 22 deletions

View File

@ -9,7 +9,10 @@ mod util;
#[cfg(unix)] #[cfg(unix)]
mod x11; mod x11;
use std::sync::Arc;
use tokio::signal::unix::{self, SignalKind}; use tokio::signal::unix::{self, SignalKind};
use tokio::sync::Notify;
use tracing::info; use tracing::info;
use tracing_subscriber::filter::EnvFilter; use tracing_subscriber::filter::EnvFilter;
@ -26,14 +29,18 @@ async fn main() {
config::load_config(std::env::var("MQTTMARIONETTE_CONFIG").ok()) config::load_config(std::env::var("MQTTMARIONETTE_CONFIG").ok())
.unwrap(); .unwrap();
let stop = Arc::new(Notify::new());
let stopwait = stop.clone();
let task = tokio::spawn(async move { let task = tokio::spawn(async move {
let session = Session::begin(config).await.unwrap(); let session = Session::begin(config).await.unwrap();
session.run().await; session.run(stopwait).await;
}); });
wait_signal().await; wait_signal().await;
task.abort(); stop.notify_waiters();
task.await;
} }
#[cfg(unix)] #[cfg(unix)]

View File

@ -1,3 +1,4 @@
use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use async_trait::async_trait; use async_trait::async_trait;
@ -7,8 +8,9 @@ use paho_mqtt::{
AsyncReceiver, ConnectOptions, ConnectOptionsBuilder, AsyncReceiver, ConnectOptions, ConnectOptionsBuilder,
CreateOptionsBuilder, ServerResponse, SslOptionsBuilder, CreateOptionsBuilder, ServerResponse, SslOptionsBuilder,
}; };
use tokio::sync::{Mutex, Notify};
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use tracing::{debug, info, trace}; use tracing::{debug, error, info, trace};
use crate::config::Configuration; use crate::config::Configuration;
use crate::hass::{self, HassConfig}; use crate::hass::{self, HassConfig};
@ -25,13 +27,17 @@ pub enum MessageType {
pub struct MqttClient<'a> { pub struct MqttClient<'a> {
config: &'a Configuration, config: &'a Configuration,
client: AsyncClient, client: Arc<Mutex<AsyncClient>>,
stream: AsyncReceiver<Option<Message>>, stream: AsyncReceiver<Option<Message>>,
topics: TopicMatcher<MessageType>, topics: TopicMatcher<MessageType>,
stop: Arc<Notify>,
} }
impl<'a> MqttClient<'a> { impl<'a> MqttClient<'a> {
pub fn new(config: &'a Configuration) -> Result<Self, Error> { pub fn new(
config: &'a Configuration,
stop: Arc<Notify>,
) -> Result<Self, Error> {
let uri = format!( let uri = format!(
"{}://{}:{}", "{}://{}:{}",
if config.mqtt.tls { "ssl" } else { "tcp" }, if config.mqtt.tls { "ssl" } else { "tcp" },
@ -45,19 +51,21 @@ impl<'a> MqttClient<'a> {
.finalize(); .finalize();
let mut client = AsyncClient::new(client_opts)?; let mut client = AsyncClient::new(client_opts)?;
let stream = client.get_stream(10); let stream = client.get_stream(10);
let client = Arc::new(Mutex::new(client));
let topics = TopicMatcher::new(); let topics = TopicMatcher::new();
Ok(Self { Ok(Self {
config, config,
client, client,
stream, stream,
topics, topics,
stop,
}) })
} }
pub async fn connect(&mut self) -> Result<ServerResponse, Error> { pub async fn connect(&mut self) -> Result<ServerResponse, Error> {
let opts = self.conn_opts()?; let opts = self.conn_opts()?;
trace!("Connect options: {:?}", opts); trace!("Connect options: {:?}", opts);
let res = self.client.connect(opts).await?; let res = self.client.lock().await.connect(opts).await?;
info!("Successfully connected to MQTT broker"); info!("Successfully connected to MQTT broker");
Ok(res) Ok(res)
} }
@ -65,7 +73,7 @@ impl<'a> MqttClient<'a> {
pub async fn subscribe(&mut self) -> Result<ServerResponse, Error> { pub async fn subscribe(&mut self) -> Result<ServerResponse, Error> {
let prefix = &self.config.mqtt.topic_prefix; let prefix = &self.config.mqtt.topic_prefix;
let t_nav = format!("{}/+/navigate", prefix); let t_nav = format!("{}/+/navigate", prefix);
let res = self.client.subscribe(&t_nav, 0).await?; let res = self.client.lock().await.subscribe(&t_nav, 0).await?;
self.topics.insert(t_nav, MessageType::Navigate); self.topics.insert(t_nav, MessageType::Navigate);
Ok(res) Ok(res)
} }
@ -73,7 +81,7 @@ impl<'a> MqttClient<'a> {
pub fn publisher(&mut self) -> MqttPublisher { pub fn publisher(&mut self) -> MqttPublisher {
MqttPublisher { MqttPublisher {
config: self.config, config: self.config,
client: &mut self.client, client: self.client.clone(),
} }
} }
@ -83,8 +91,17 @@ impl<'a> MqttClient<'a> {
{ {
let publisher = MqttPublisher { let publisher = MqttPublisher {
config: self.config, config: self.config,
client: &mut self.client, client: self.client.clone(),
}; };
let msg = self.offline_message();
tokio::spawn(async move {
self.stop.notified().await;
let client = self.client.lock().await;
if let Err(e) = client.publish(msg).await {
error!("Failed to publish offline message: {}", e);
}
client.disconnect(None);
});
while let Some(msg) = self.stream.next().await { while let Some(msg) = self.stream.next().await {
let Some(msg) = msg else {continue}; let Some(msg) = msg else {continue};
trace!("Received message: {:?}", msg); trace!("Received message: {:?}", msg);
@ -100,11 +117,7 @@ impl<'a> MqttClient<'a> {
fn conn_opts(&self) -> Result<ConnectOptions, Error> { fn conn_opts(&self) -> Result<ConnectOptions, Error> {
let mut conn_opts = ConnectOptionsBuilder::new(); let mut conn_opts = ConnectOptionsBuilder::new();
conn_opts.will_message(Message::new_retained( conn_opts.will_message(self.offline_message());
format!("{}/available", self.config.mqtt.topic_prefix),
"offline",
0,
));
conn_opts.automatic_reconnect( conn_opts.automatic_reconnect(
Duration::from_millis(500), Duration::from_millis(500),
Duration::from_secs(30), Duration::from_secs(30),
@ -122,11 +135,19 @@ impl<'a> MqttClient<'a> {
} }
Ok(conn_opts.finalize()) Ok(conn_opts.finalize())
} }
fn offline_message(&self) -> Message {
Message::new_retained(
format!("{}/available", self.config.mqtt.topic_prefix),
"offline",
0,
)
}
} }
pub struct MqttPublisher<'a> { pub struct MqttPublisher<'a> {
config: &'a Configuration, config: &'a Configuration,
client: &'a mut AsyncClient, client: Arc<tokio::sync::Mutex<AsyncClient>>,
} }
impl<'a> MqttPublisher<'a> { impl<'a> MqttPublisher<'a> {
@ -138,7 +159,7 @@ impl<'a> MqttPublisher<'a> {
let topic = let topic =
format!("{}/{}/title", self.config.mqtt.topic_prefix, screen); format!("{}/{}/title", self.config.mqtt.topic_prefix, screen);
let msg = Message::new_retained(topic, title, 0); let msg = Message::new_retained(topic, title, 0);
self.client.publish(msg).await?; (*self.client.lock().await).publish(msg).await?;
Ok(()) Ok(())
} }
@ -150,7 +171,7 @@ impl<'a> MqttPublisher<'a> {
let topic = let topic =
format!("{}/{}/url", self.config.mqtt.topic_prefix, screen); format!("{}/{}/url", self.config.mqtt.topic_prefix, screen);
let msg = Message::new_retained(topic, url, 0); let msg = Message::new_retained(topic, url, 0);
self.client.publish(msg).await?; self.client.lock().await.publish(msg).await?;
Ok(()) Ok(())
} }
@ -171,7 +192,7 @@ impl<'a> MqttPublisher<'a> {
let object_id = unique_id.clone(); let object_id = unique_id.clone();
let msg = Message::new_retained(&availability_topic, "online", 0); let msg = Message::new_retained(&availability_topic, "online", 0);
trace!("Publishing message: {:?}", msg); trace!("Publishing message: {:?}", msg);
self.client.publish(msg).await?; self.client.lock().await.publish(msg).await?;
let config = HassConfig { let config = HassConfig {
availability_topic, availability_topic,
command_topic, command_topic,
@ -189,7 +210,7 @@ impl<'a> MqttPublisher<'a> {
0, 0,
); );
trace!("Publishing message: {:?}", msg); trace!("Publishing message: {:?}", msg);
self.client.publish(msg).await?; self.client.lock().await.publish(msg).await?;
let unique_id = format!("sensor.{}_title", key); let unique_id = format!("sensor.{}_title", key);
let object_id = unique_id.clone(); let object_id = unique_id.clone();
@ -209,9 +230,10 @@ impl<'a> MqttPublisher<'a> {
); );
trace!("Publishing message: {:?}", msg); trace!("Publishing message: {:?}", msg);
self.client.publish(msg).await?; self.client.lock().await.publish(msg).await?;
info!("Succesfully published Home Assistant config"); info!("Succesfully published Home Assistant config");
Ok(()) Ok(())
} }
} }

View File

@ -1,6 +1,8 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::sync::Notify;
use tracing::{debug, error, info, trace, warn}; use tracing::{debug, error, info, trace, warn};
use crate::browser::{Browser, BrowserError}; use crate::browser::{Browser, BrowserError};
@ -98,7 +100,7 @@ impl Session {
}) })
} }
pub async fn run(mut self) { pub async fn run(mut self, stop: Arc<Notify>) {
let windows = match self.init_windows().await { let windows = match self.init_windows().await {
Ok(w) => w, Ok(w) => w,
Err(e) => { Err(e) => {
@ -107,7 +109,7 @@ impl Session {
} }
}; };
let mut client = MqttClient::new(&self.config).unwrap(); let mut client = MqttClient::new(&self.config, stop).unwrap();
loop { loop {
if let Err(e) = client.connect().await { if let Err(e) = client.connect().await {
warn!("Failed to connect to MQTT server: {}", e); warn!("Failed to connect to MQTT server: {}", e);