mqtt: Manually handle reconnects

The Paho automatic reconnect capability is pretty useless.  While it
does automatically connect to the broker again after an unexpected
disconnect, it does not subscribe to the topics again.  Further, since
the broker will send the will message when the client disconnects
unexpectedly, we need to send our "online" availability message when we
reconnect.

To resolve both of these problems, the `MqttClient::run` method now
takes care reconnecting manually.  When it receives a disconnect
notification, it explicitly calls `AsyncClient::reconnect`.  Once that
succeeds, it resubscribes to the command topics and publishes an
"online" message.
dev/ci
Dustin 2023-01-08 12:54:32 -06:00
parent 6e8ba8b7b6
commit a2acdfd0dc
2 changed files with 37 additions and 21 deletions

View File

@ -10,7 +10,7 @@ use paho_mqtt::{
}; };
use tokio::sync::{Mutex, Notify}; use tokio::sync::{Mutex, Notify};
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use tracing::{debug, error, info, trace}; use tracing::{debug, error, info, trace, warn};
use crate::config::Configuration; use crate::config::Configuration;
use crate::hass::{self, HassConfig}; use crate::hass::{self, HassConfig};
@ -30,13 +30,11 @@ pub struct MqttClient<'a> {
client: Arc<Mutex<AsyncClient>>, client: Arc<Mutex<AsyncClient>>,
stream: AsyncReceiver<Option<Message>>, stream: AsyncReceiver<Option<Message>>,
topics: TopicMatcher<MessageType>, topics: TopicMatcher<MessageType>,
stop: Arc<Notify>,
} }
impl<'a> MqttClient<'a> { impl<'a> MqttClient<'a> {
pub fn new( pub fn new(
config: &'a Configuration, config: &'a Configuration,
stop: Arc<Notify>,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
let uri = format!( let uri = format!(
"{}://{}:{}", "{}://{}:{}",
@ -58,7 +56,6 @@ impl<'a> MqttClient<'a> {
client, client,
stream, stream,
topics, topics,
stop,
}) })
} }
@ -67,6 +64,7 @@ impl<'a> MqttClient<'a> {
trace!("Connect options: {:?}", opts); trace!("Connect options: {:?}", opts);
let res = self.client.lock().await.connect(opts).await?; let res = self.client.lock().await.connect(opts).await?;
info!("Successfully connected to MQTT broker"); info!("Successfully connected to MQTT broker");
self.on_connect().await;
Ok(res) Ok(res)
} }
@ -85,7 +83,7 @@ impl<'a> MqttClient<'a> {
} }
} }
pub async fn run<H>(mut self, mut handler: H) pub async fn run<H>(mut self, mut handler: H, stop: Arc<Notify>)
where where
H: MessageHandler, H: MessageHandler,
{ {
@ -94,16 +92,26 @@ impl<'a> MqttClient<'a> {
client: self.client.clone(), client: self.client.clone(),
}; };
let msg = self.offline_message(); let msg = self.offline_message();
let client = self.client.clone();
tokio::spawn(async move { tokio::spawn(async move {
self.stop.notified().await; stop.notified().await;
let client = self.client.lock().await; let client = client.lock().await;
if let Err(e) = client.publish(msg).await { if let Err(e) = client.publish(msg).await {
error!("Failed to publish offline message: {}", e); error!("Failed to publish offline message: {}", e);
} }
client.stop_stream();
client.disconnect(None); client.disconnect(None);
}); });
while let Some(msg) = self.stream.next().await { while let Some(msg) = self.stream.next().await {
let Some(msg) = msg else {break}; let Some(msg) = msg else {
warn!("Lost connection to MQTT broker, reconnecting");
while let Err(e) = self.client.lock().await.reconnect().await {
error!("Reconnect failed: {}", e);
tokio::time::sleep(Duration::from_secs(1)).await;
}
self.on_connect().await;
continue;
};
trace!("Received message: {:?}", msg); trace!("Received message: {:?}", msg);
for m in self.topics.matches(msg.topic()) { for m in self.topics.matches(msg.topic()) {
match m.1 { match m.1 {
@ -118,10 +126,6 @@ impl<'a> MqttClient<'a> {
fn conn_opts(&self) -> Result<ConnectOptions, Error> { fn conn_opts(&self) -> Result<ConnectOptions, Error> {
let mut conn_opts = ConnectOptionsBuilder::new(); let mut conn_opts = ConnectOptionsBuilder::new();
conn_opts.will_message(self.offline_message()); conn_opts.will_message(self.offline_message());
conn_opts.automatic_reconnect(
Duration::from_millis(500),
Duration::from_secs(30),
);
if self.config.mqtt.tls { if self.config.mqtt.tls {
let ssl_opts = SslOptionsBuilder::new() let ssl_opts = SslOptionsBuilder::new()
.trust_store(&self.config.mqtt.ca_file)? .trust_store(&self.config.mqtt.ca_file)?
@ -136,6 +140,25 @@ impl<'a> MqttClient<'a> {
Ok(conn_opts.finalize()) Ok(conn_opts.finalize())
} }
async fn on_connect(&mut self) {
if let Err(e) = self.subscribe().await {
warn!("Error subscribing to MQTT topics: {}", e);
tokio::time::sleep(Duration::from_secs(1)).await;
}
let client = self.client.lock().await;
if let Err(e) = client.publish(self.online_message()).await {
error!("Failed to publish availability message: {}", e);
}
}
fn online_message(&self) -> Message {
Message::new_retained(
format!("{}/available", self.config.mqtt.topic_prefix),
"online",
0,
)
}
fn offline_message(&self) -> Message { fn offline_message(&self) -> Message {
Message::new_retained( Message::new_retained(
format!("{}/available", self.config.mqtt.topic_prefix), format!("{}/available", self.config.mqtt.topic_prefix),
@ -228,12 +251,10 @@ impl<'a> MqttPublisher<'a> {
serde_json::to_string(&config).unwrap(), serde_json::to_string(&config).unwrap(),
0, 0,
); );
trace!("Publishing message: {:?}", msg); trace!("Publishing message: {:?}", msg);
self.client.lock().await.publish(msg).await?; self.client.lock().await.publish(msg).await?;
info!("Succesfully published Home Assistant config"); info!("Succesfully published Home Assistant config");
Ok(()) Ok(())
} }
} }

View File

@ -109,18 +109,13 @@ impl Session {
} }
}; };
let mut client = MqttClient::new(&self.config, stop).unwrap(); let mut client = MqttClient::new(&self.config).unwrap();
loop { loop {
if let Err(e) = client.connect().await { if let Err(e) = client.connect().await {
warn!("Failed to connect to MQTT server: {}", e); warn!("Failed to connect to MQTT server: {}", e);
tokio::time::sleep(Duration::from_secs(1)).await; tokio::time::sleep(Duration::from_secs(1)).await;
continue; continue;
} }
if let Err(e) = client.subscribe().await {
warn!("Error subscribing to MQTT topics: {}", e);
tokio::time::sleep(Duration::from_secs(1)).await;
continue;
}
break; break;
} }
@ -138,7 +133,7 @@ impl Session {
marionette: &mut self.marionette, marionette: &mut self.marionette,
windows, windows,
}; };
client.run(handler).await; client.run(handler, stop).await;
} }
async fn init_windows( async fn init_windows(