diff --git a/src/mqtt.rs b/src/mqtt.rs index 179f47f..ab57ac6 100644 --- a/src/mqtt.rs +++ b/src/mqtt.rs @@ -10,7 +10,7 @@ use paho_mqtt::{ }; use tokio::sync::{Mutex, Notify}; use tokio_stream::StreamExt; -use tracing::{debug, error, info, trace}; +use tracing::{debug, error, info, trace, warn}; use crate::config::Configuration; use crate::hass::{self, HassConfig}; @@ -30,13 +30,11 @@ pub struct MqttClient<'a> { client: Arc>, stream: AsyncReceiver>, topics: TopicMatcher, - stop: Arc, } impl<'a> MqttClient<'a> { pub fn new( config: &'a Configuration, - stop: Arc, ) -> Result { let uri = format!( "{}://{}:{}", @@ -58,7 +56,6 @@ impl<'a> MqttClient<'a> { client, stream, topics, - stop, }) } @@ -67,6 +64,7 @@ impl<'a> MqttClient<'a> { trace!("Connect options: {:?}", opts); let res = self.client.lock().await.connect(opts).await?; info!("Successfully connected to MQTT broker"); + self.on_connect().await; Ok(res) } @@ -85,7 +83,7 @@ impl<'a> MqttClient<'a> { } } - pub async fn run(mut self, mut handler: H) + pub async fn run(mut self, mut handler: H, stop: Arc) where H: MessageHandler, { @@ -94,16 +92,26 @@ impl<'a> MqttClient<'a> { client: self.client.clone(), }; let msg = self.offline_message(); + let client = self.client.clone(); tokio::spawn(async move { - self.stop.notified().await; - let client = self.client.lock().await; + stop.notified().await; + let client = client.lock().await; if let Err(e) = client.publish(msg).await { error!("Failed to publish offline message: {}", e); } + client.stop_stream(); client.disconnect(None); }); while let Some(msg) = self.stream.next().await { - let Some(msg) = msg else {break}; + let Some(msg) = msg else { + warn!("Lost connection to MQTT broker, reconnecting"); + while let Err(e) = self.client.lock().await.reconnect().await { + error!("Reconnect failed: {}", e); + tokio::time::sleep(Duration::from_secs(1)).await; + } + self.on_connect().await; + continue; + }; trace!("Received message: {:?}", msg); for m in self.topics.matches(msg.topic()) { match m.1 { @@ -118,10 +126,6 @@ impl<'a> MqttClient<'a> { fn conn_opts(&self) -> Result { let mut conn_opts = ConnectOptionsBuilder::new(); conn_opts.will_message(self.offline_message()); - conn_opts.automatic_reconnect( - Duration::from_millis(500), - Duration::from_secs(30), - ); if self.config.mqtt.tls { let ssl_opts = SslOptionsBuilder::new() .trust_store(&self.config.mqtt.ca_file)? @@ -136,6 +140,25 @@ impl<'a> MqttClient<'a> { Ok(conn_opts.finalize()) } + async fn on_connect(&mut self) { + if let Err(e) = self.subscribe().await { + warn!("Error subscribing to MQTT topics: {}", e); + tokio::time::sleep(Duration::from_secs(1)).await; + } + let client = self.client.lock().await; + if let Err(e) = client.publish(self.online_message()).await { + error!("Failed to publish availability message: {}", e); + } + } + + fn online_message(&self) -> Message { + Message::new_retained( + format!("{}/available", self.config.mqtt.topic_prefix), + "online", + 0, + ) + } + fn offline_message(&self) -> Message { Message::new_retained( format!("{}/available", self.config.mqtt.topic_prefix), @@ -228,12 +251,10 @@ impl<'a> MqttPublisher<'a> { serde_json::to_string(&config).unwrap(), 0, ); - trace!("Publishing message: {:?}", msg); self.client.lock().await.publish(msg).await?; info!("Succesfully published Home Assistant config"); Ok(()) } - } diff --git a/src/session.rs b/src/session.rs index b4ba55a..3f4d3a9 100644 --- a/src/session.rs +++ b/src/session.rs @@ -109,18 +109,13 @@ impl Session { } }; - let mut client = MqttClient::new(&self.config, stop).unwrap(); + let mut client = MqttClient::new(&self.config).unwrap(); loop { if let Err(e) = client.connect().await { warn!("Failed to connect to MQTT server: {}", e); tokio::time::sleep(Duration::from_secs(1)).await; continue; } - if let Err(e) = client.subscribe().await { - warn!("Error subscribing to MQTT topics: {}", e); - tokio::time::sleep(Duration::from_secs(1)).await; - continue; - } break; } @@ -138,7 +133,7 @@ impl Session { marionette: &mut self.marionette, windows, }; - client.run(handler).await; + client.run(handler, stop).await; } async fn init_windows(