From 94a47d863ca7848841d23dc2cac5d652ab87a897 Mon Sep 17 00:00:00 2001 From: "Dustin C. Hatch" Date: Sat, 7 Jan 2023 22:00:59 -0600 Subject: [PATCH] Send offline message on shutdown The MQTT broker does *not* send the client's last will and testament message when the client disconnects gracefully. I thought it did because my other Rust/MQTT project, MQTTDPMS, seems to behave that way. It turns out, though, that in that project, the client never actually disconnects gracefully. It has no signal handlers, so when it receives SIGINT or SIGTERM, it just exits immediately. This leaves the OS to forcefully close the TCP connection, so the broker sends the will message. Since the Browser HUD process *does* have signal handlers, when it receives a signal, it shuts down gracefully. As Rust drops objects during shut down, the MQTT client eventually disconnects cleanly, so the broker does not send the will message. In order to notify Home Assistant that the device is now unavailable, we have to explicitly send the offline message before disconnecting the MQTT client. I've added a `Notify` object that lives for the entire life of the process and is passed in to the session. When a signal is received, this object wakes up the asynchronous tasks that perform the pre-shutdown operations. One such task is spawned by the `MqttClient::run` method; it sends the offline message when notified, then disconnects the MQTT client. In order to share the MQTT client object between this new task and the message receive loop, it has to be wrapped in an `Arc` and a `Mutex`. --- src/main.rs | 11 ++++++++-- src/mqtt.rs | 58 ++++++++++++++++++++++++++++++++++---------------- src/session.rs | 6 ++++-- 3 files changed, 53 insertions(+), 22 deletions(-) diff --git a/src/main.rs b/src/main.rs index a5c877f..25ae572 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,10 @@ mod util; #[cfg(unix)] mod x11; +use std::sync::Arc; + use tokio::signal::unix::{self, SignalKind}; +use tokio::sync::Notify; use tracing::info; use tracing_subscriber::filter::EnvFilter; @@ -26,14 +29,18 @@ async fn main() { config::load_config(std::env::var("MQTTMARIONETTE_CONFIG").ok()) .unwrap(); + let stop = Arc::new(Notify::new()); + let stopwait = stop.clone(); + let task = tokio::spawn(async move { let session = Session::begin(config).await.unwrap(); - session.run().await; + session.run(stopwait).await; }); wait_signal().await; - task.abort(); + stop.notify_waiters(); + task.await; } #[cfg(unix)] diff --git a/src/mqtt.rs b/src/mqtt.rs index 0bb5013..6c8e6c7 100644 --- a/src/mqtt.rs +++ b/src/mqtt.rs @@ -1,3 +1,4 @@ +use std::sync::Arc; use std::time::Duration; use async_trait::async_trait; @@ -7,8 +8,9 @@ use paho_mqtt::{ AsyncReceiver, ConnectOptions, ConnectOptionsBuilder, CreateOptionsBuilder, ServerResponse, SslOptionsBuilder, }; +use tokio::sync::{Mutex, Notify}; use tokio_stream::StreamExt; -use tracing::{debug, info, trace}; +use tracing::{debug, error, info, trace}; use crate::config::Configuration; use crate::hass::{self, HassConfig}; @@ -25,13 +27,17 @@ pub enum MessageType { pub struct MqttClient<'a> { config: &'a Configuration, - client: AsyncClient, + client: Arc>, stream: AsyncReceiver>, topics: TopicMatcher, + stop: Arc, } impl<'a> MqttClient<'a> { - pub fn new(config: &'a Configuration) -> Result { + pub fn new( + config: &'a Configuration, + stop: Arc, + ) -> Result { let uri = format!( "{}://{}:{}", if config.mqtt.tls { "ssl" } else { "tcp" }, @@ -45,19 +51,21 @@ impl<'a> MqttClient<'a> { .finalize(); let mut client = AsyncClient::new(client_opts)?; let stream = client.get_stream(10); + let client = Arc::new(Mutex::new(client)); let topics = TopicMatcher::new(); Ok(Self { config, client, stream, topics, + stop, }) } pub async fn connect(&mut self) -> Result { let opts = self.conn_opts()?; trace!("Connect options: {:?}", opts); - let res = self.client.connect(opts).await?; + let res = self.client.lock().await.connect(opts).await?; info!("Successfully connected to MQTT broker"); Ok(res) } @@ -65,7 +73,7 @@ impl<'a> MqttClient<'a> { pub async fn subscribe(&mut self) -> Result { let prefix = &self.config.mqtt.topic_prefix; let t_nav = format!("{}/+/navigate", prefix); - let res = self.client.subscribe(&t_nav, 0).await?; + let res = self.client.lock().await.subscribe(&t_nav, 0).await?; self.topics.insert(t_nav, MessageType::Navigate); Ok(res) } @@ -73,7 +81,7 @@ impl<'a> MqttClient<'a> { pub fn publisher(&mut self) -> MqttPublisher { MqttPublisher { config: self.config, - client: &mut self.client, + client: self.client.clone(), } } @@ -83,8 +91,17 @@ impl<'a> MqttClient<'a> { { let publisher = MqttPublisher { config: self.config, - client: &mut self.client, + client: self.client.clone(), }; + let msg = self.offline_message(); + tokio::spawn(async move { + self.stop.notified().await; + let client = self.client.lock().await; + if let Err(e) = client.publish(msg).await { + error!("Failed to publish offline message: {}", e); + } + client.disconnect(None); + }); while let Some(msg) = self.stream.next().await { let Some(msg) = msg else {continue}; trace!("Received message: {:?}", msg); @@ -100,11 +117,7 @@ impl<'a> MqttClient<'a> { fn conn_opts(&self) -> Result { let mut conn_opts = ConnectOptionsBuilder::new(); - conn_opts.will_message(Message::new_retained( - format!("{}/available", self.config.mqtt.topic_prefix), - "offline", - 0, - )); + conn_opts.will_message(self.offline_message()); conn_opts.automatic_reconnect( Duration::from_millis(500), Duration::from_secs(30), @@ -122,11 +135,19 @@ impl<'a> MqttClient<'a> { } Ok(conn_opts.finalize()) } + + fn offline_message(&self) -> Message { + Message::new_retained( + format!("{}/available", self.config.mqtt.topic_prefix), + "offline", + 0, + ) + } } pub struct MqttPublisher<'a> { config: &'a Configuration, - client: &'a mut AsyncClient, + client: Arc>, } impl<'a> MqttPublisher<'a> { @@ -138,7 +159,7 @@ impl<'a> MqttPublisher<'a> { let topic = format!("{}/{}/title", self.config.mqtt.topic_prefix, screen); let msg = Message::new_retained(topic, title, 0); - self.client.publish(msg).await?; + (*self.client.lock().await).publish(msg).await?; Ok(()) } @@ -150,7 +171,7 @@ impl<'a> MqttPublisher<'a> { let topic = format!("{}/{}/url", self.config.mqtt.topic_prefix, screen); let msg = Message::new_retained(topic, url, 0); - self.client.publish(msg).await?; + self.client.lock().await.publish(msg).await?; Ok(()) } @@ -171,7 +192,7 @@ impl<'a> MqttPublisher<'a> { let object_id = unique_id.clone(); let msg = Message::new_retained(&availability_topic, "online", 0); trace!("Publishing message: {:?}", msg); - self.client.publish(msg).await?; + self.client.lock().await.publish(msg).await?; let config = HassConfig { availability_topic, command_topic, @@ -189,7 +210,7 @@ impl<'a> MqttPublisher<'a> { 0, ); trace!("Publishing message: {:?}", msg); - self.client.publish(msg).await?; + self.client.lock().await.publish(msg).await?; let unique_id = format!("sensor.{}_title", key); let object_id = unique_id.clone(); @@ -209,9 +230,10 @@ impl<'a> MqttPublisher<'a> { ); trace!("Publishing message: {:?}", msg); - self.client.publish(msg).await?; + self.client.lock().await.publish(msg).await?; info!("Succesfully published Home Assistant config"); Ok(()) } + } diff --git a/src/session.rs b/src/session.rs index eae20f2..b4ba55a 100644 --- a/src/session.rs +++ b/src/session.rs @@ -1,6 +1,8 @@ use std::collections::HashMap; +use std::sync::Arc; use std::time::Duration; +use tokio::sync::Notify; use tracing::{debug, error, info, trace, warn}; use crate::browser::{Browser, BrowserError}; @@ -98,7 +100,7 @@ impl Session { }) } - pub async fn run(mut self) { + pub async fn run(mut self, stop: Arc) { let windows = match self.init_windows().await { Ok(w) => w, Err(e) => { @@ -107,7 +109,7 @@ impl Session { } }; - let mut client = MqttClient::new(&self.config).unwrap(); + let mut client = MqttClient::new(&self.config, stop).unwrap(); loop { if let Err(e) = client.connect().await { warn!("Failed to connect to MQTT server: {}", e);