diff --git a/src/main.rs b/src/main.rs index a5c877f..25ae572 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,10 @@ mod util; #[cfg(unix)] mod x11; +use std::sync::Arc; + use tokio::signal::unix::{self, SignalKind}; +use tokio::sync::Notify; use tracing::info; use tracing_subscriber::filter::EnvFilter; @@ -26,14 +29,18 @@ async fn main() { config::load_config(std::env::var("MQTTMARIONETTE_CONFIG").ok()) .unwrap(); + let stop = Arc::new(Notify::new()); + let stopwait = stop.clone(); + let task = tokio::spawn(async move { let session = Session::begin(config).await.unwrap(); - session.run().await; + session.run(stopwait).await; }); wait_signal().await; - task.abort(); + stop.notify_waiters(); + task.await; } #[cfg(unix)] diff --git a/src/mqtt.rs b/src/mqtt.rs index 0bb5013..6c8e6c7 100644 --- a/src/mqtt.rs +++ b/src/mqtt.rs @@ -1,3 +1,4 @@ +use std::sync::Arc; use std::time::Duration; use async_trait::async_trait; @@ -7,8 +8,9 @@ use paho_mqtt::{ AsyncReceiver, ConnectOptions, ConnectOptionsBuilder, CreateOptionsBuilder, ServerResponse, SslOptionsBuilder, }; +use tokio::sync::{Mutex, Notify}; use tokio_stream::StreamExt; -use tracing::{debug, info, trace}; +use tracing::{debug, error, info, trace}; use crate::config::Configuration; use crate::hass::{self, HassConfig}; @@ -25,13 +27,17 @@ pub enum MessageType { pub struct MqttClient<'a> { config: &'a Configuration, - client: AsyncClient, + client: Arc>, stream: AsyncReceiver>, topics: TopicMatcher, + stop: Arc, } impl<'a> MqttClient<'a> { - pub fn new(config: &'a Configuration) -> Result { + pub fn new( + config: &'a Configuration, + stop: Arc, + ) -> Result { let uri = format!( "{}://{}:{}", if config.mqtt.tls { "ssl" } else { "tcp" }, @@ -45,19 +51,21 @@ impl<'a> MqttClient<'a> { .finalize(); let mut client = AsyncClient::new(client_opts)?; let stream = client.get_stream(10); + let client = Arc::new(Mutex::new(client)); let topics = TopicMatcher::new(); Ok(Self { config, client, stream, topics, + stop, }) } pub async fn connect(&mut self) -> Result { let opts = self.conn_opts()?; trace!("Connect options: {:?}", opts); - let res = self.client.connect(opts).await?; + let res = self.client.lock().await.connect(opts).await?; info!("Successfully connected to MQTT broker"); Ok(res) } @@ -65,7 +73,7 @@ impl<'a> MqttClient<'a> { pub async fn subscribe(&mut self) -> Result { let prefix = &self.config.mqtt.topic_prefix; let t_nav = format!("{}/+/navigate", prefix); - let res = self.client.subscribe(&t_nav, 0).await?; + let res = self.client.lock().await.subscribe(&t_nav, 0).await?; self.topics.insert(t_nav, MessageType::Navigate); Ok(res) } @@ -73,7 +81,7 @@ impl<'a> MqttClient<'a> { pub fn publisher(&mut self) -> MqttPublisher { MqttPublisher { config: self.config, - client: &mut self.client, + client: self.client.clone(), } } @@ -83,8 +91,17 @@ impl<'a> MqttClient<'a> { { let publisher = MqttPublisher { config: self.config, - client: &mut self.client, + client: self.client.clone(), }; + let msg = self.offline_message(); + tokio::spawn(async move { + self.stop.notified().await; + let client = self.client.lock().await; + if let Err(e) = client.publish(msg).await { + error!("Failed to publish offline message: {}", e); + } + client.disconnect(None); + }); while let Some(msg) = self.stream.next().await { let Some(msg) = msg else {continue}; trace!("Received message: {:?}", msg); @@ -100,11 +117,7 @@ impl<'a> MqttClient<'a> { fn conn_opts(&self) -> Result { let mut conn_opts = ConnectOptionsBuilder::new(); - conn_opts.will_message(Message::new_retained( - format!("{}/available", self.config.mqtt.topic_prefix), - "offline", - 0, - )); + conn_opts.will_message(self.offline_message()); conn_opts.automatic_reconnect( Duration::from_millis(500), Duration::from_secs(30), @@ -122,11 +135,19 @@ impl<'a> MqttClient<'a> { } Ok(conn_opts.finalize()) } + + fn offline_message(&self) -> Message { + Message::new_retained( + format!("{}/available", self.config.mqtt.topic_prefix), + "offline", + 0, + ) + } } pub struct MqttPublisher<'a> { config: &'a Configuration, - client: &'a mut AsyncClient, + client: Arc>, } impl<'a> MqttPublisher<'a> { @@ -138,7 +159,7 @@ impl<'a> MqttPublisher<'a> { let topic = format!("{}/{}/title", self.config.mqtt.topic_prefix, screen); let msg = Message::new_retained(topic, title, 0); - self.client.publish(msg).await?; + (*self.client.lock().await).publish(msg).await?; Ok(()) } @@ -150,7 +171,7 @@ impl<'a> MqttPublisher<'a> { let topic = format!("{}/{}/url", self.config.mqtt.topic_prefix, screen); let msg = Message::new_retained(topic, url, 0); - self.client.publish(msg).await?; + self.client.lock().await.publish(msg).await?; Ok(()) } @@ -171,7 +192,7 @@ impl<'a> MqttPublisher<'a> { let object_id = unique_id.clone(); let msg = Message::new_retained(&availability_topic, "online", 0); trace!("Publishing message: {:?}", msg); - self.client.publish(msg).await?; + self.client.lock().await.publish(msg).await?; let config = HassConfig { availability_topic, command_topic, @@ -189,7 +210,7 @@ impl<'a> MqttPublisher<'a> { 0, ); trace!("Publishing message: {:?}", msg); - self.client.publish(msg).await?; + self.client.lock().await.publish(msg).await?; let unique_id = format!("sensor.{}_title", key); let object_id = unique_id.clone(); @@ -209,9 +230,10 @@ impl<'a> MqttPublisher<'a> { ); trace!("Publishing message: {:?}", msg); - self.client.publish(msg).await?; + self.client.lock().await.publish(msg).await?; info!("Succesfully published Home Assistant config"); Ok(()) } + } diff --git a/src/session.rs b/src/session.rs index eae20f2..b4ba55a 100644 --- a/src/session.rs +++ b/src/session.rs @@ -1,6 +1,8 @@ use std::collections::HashMap; +use std::sync::Arc; use std::time::Duration; +use tokio::sync::Notify; use tracing::{debug, error, info, trace, warn}; use crate::browser::{Browser, BrowserError}; @@ -98,7 +100,7 @@ impl Session { }) } - pub async fn run(mut self) { + pub async fn run(mut self, stop: Arc) { let windows = match self.init_windows().await { Ok(w) => w, Err(e) => { @@ -107,7 +109,7 @@ impl Session { } }; - let mut client = MqttClient::new(&self.config).unwrap(); + let mut client = MqttClient::new(&self.config, stop).unwrap(); loop { if let Err(e) = client.connect().await { warn!("Failed to connect to MQTT server: {}", e);