diff --git a/homie-device/src/lib.rs b/homie-device/src/lib.rs index 47ffb752..e2eefd58 100644 --- a/homie-device/src/lib.rs +++ b/homie-device/src/lib.rs @@ -9,14 +9,16 @@ use futures::FutureExt; use mac_address::get_mac_address; use rumqttc::{ self, AsyncClient, ClientError, ConnectionError, Event, EventLoop, Incoming, LastWill, - MqttOptions, QoS, + MqttOptions, Outgoing, QoS, StateError, }; use std::fmt::{self, Debug, Display, Formatter}; use std::future::Future; use std::pin::Pin; use std::str; +use std::sync::Arc; use std::time::{Duration, Instant}; use thiserror::Error; +use tokio::sync::Mutex; use tokio::task::{self, JoinError, JoinHandle}; use tokio::time::sleep; @@ -30,6 +32,10 @@ const HOMIE_IMPLEMENTATION: &str = "homie-rs"; const STATS_INTERVAL: Duration = Duration::from_secs(60); const REQUESTS_CAP: usize = 10; +/// The default duration to wait between attempts to reconnect to the MQTT broker if the connection +/// is lost. +pub const DEFAULT_RECONNECT_INTERVAL: Duration = Duration::from_secs(5); + /// Error type for futures representing tasks spawned by this crate. #[derive(Error, Debug)] pub enum SpawnError { @@ -100,6 +106,7 @@ pub struct HomieDeviceBuilder { firmware_version: Option, mqtt_options: MqttOptions, update_callback: Option, + reconnect_interval: Duration, } impl Debug for HomieDeviceBuilder { @@ -127,6 +134,14 @@ impl HomieDeviceBuilder { self.firmware_version = Some(firmware_version.to_string()); } + /// Set the duration to wait between attempts to reconnect to the MQTT broker if the connection + /// is lost. + /// + /// If this is not set it will default to `DEFAULT_RECONNECT_INTERVAL`. + pub fn set_reconnect_interval(&mut self, reconnect_interval: Duration) { + self.reconnect_interval = reconnect_interval; + } + pub fn set_update_callback(&mut self, mut update_callback: F) where F: (FnMut(String, String, String) -> Fut) + Send + Sync + 'static, @@ -182,6 +197,8 @@ impl HomieDeviceBuilder { ); last_will.retain = true; mqtt_options.set_last_will(last_will); + // Setting this to false means that our subscriptions will be kept when we reconnect. + mqtt_options.set_clean_session(false); let (client, event_loop) = AsyncClient::new(mqtt_options, REQUESTS_CAP); let publisher = DevicePublisher::new(client, self.device_base); @@ -201,7 +218,12 @@ impl HomieDeviceBuilder { None }; - let homie = HomieDevice::new(publisher, self.device_name, &extension_ids); + let homie = HomieDevice::new( + publisher, + self.device_name, + &extension_ids, + self.reconnect_interval, + ); (event_loop, homie, stats, firmware, self.update_callback) } @@ -211,11 +233,8 @@ impl HomieDeviceBuilder { /// single MQTT connection. #[derive(Debug)] pub struct HomieDevice { - publisher: DevicePublisher, - device_name: String, - nodes: Vec, - state: State, - extension_ids: String, + state: Arc>, + reconnect_interval: Duration, } impl HomieDevice { @@ -240,34 +259,48 @@ impl HomieDevice { firmware_version: None, mqtt_options, update_callback: None, + reconnect_interval: DEFAULT_RECONNECT_INTERVAL, } } - fn new(publisher: DevicePublisher, device_name: String, extension_ids: &[&str]) -> HomieDevice { + fn new( + publisher: DevicePublisher, + device_name: String, + extension_ids: &[&str], + reconnect_interval: Duration, + ) -> HomieDevice { HomieDevice { - publisher, - device_name, - nodes: vec![], - state: State::Disconnected, - extension_ids: extension_ids.join(","), + state: Arc::new(Mutex::new(DeviceState { + publisher, + device_name, + nodes: vec![], + state: State::Disconnected, + extension_ids: extension_ids.join(","), + })), + reconnect_interval, } } async fn start(&mut self) -> Result<(), ClientError> { - assert_eq!(self.state, State::Disconnected); - self.publisher + let mut state = self.state.lock().await; + assert_eq!(state.state, State::Disconnected); + state + .publisher .publish_retained("$homie", HOMIE_VERSION) .await?; - self.publisher - .publish_retained("$extensions", self.extension_ids.as_str()) + state + .publisher + .publish_retained("$extensions", state.extension_ids.as_str()) .await?; - self.publisher + state + .publisher .publish_retained("$implementation", HOMIE_IMPLEMENTATION) .await?; - self.publisher - .publish_retained("$name", self.device_name.as_str()) + state + .publisher + .publish_retained("$name", state.device_name.as_str()) .await?; - self.set_state(State::Init).await?; + state.set_state(State::Init).await?; Ok(()) } @@ -277,29 +310,61 @@ impl HomieDevice { mut event_loop: EventLoop, mut update_callback: Option, ) -> impl Future> { - let device_base = format!("{}/", self.publisher.device_base); let (incoming_tx, incoming_rx) = async_channel::unbounded(); + let reconnect_interval = self.reconnect_interval; let mqtt_task = task::spawn(async move { + let mut disconnect_requested = false; loop { - let notification = event_loop.poll().await?; - log::trace!("Notification = {:?}", notification); - - if let Event::Incoming(incoming) = notification { - incoming_tx.send(incoming).await.map_err(|_| { - SpawnError::Internal("Incoming event channel receiver closed.") - })?; + match event_loop.poll().await { + Ok(notification) => { + log::trace!("Notification = {:?}", notification); + + match notification { + Event::Incoming(incoming) => { + incoming_tx.send(incoming).await.map_err(|_| { + SpawnError::Internal("Incoming event channel receiver closed.") + })?; + } + Event::Outgoing(Outgoing::Disconnect) => { + // Flag that we have tried to disconnect intentionally, but keep + // polling until we get an error implying that we have actually disconnected. + disconnect_requested = true; + } + Event::Outgoing(_) => {} + } + } + Err(e) => { + if disconnect_requested { + log::trace!("Disconnected as requested."); + return Ok(()); + } + log::error!("Failed to poll EventLoop: {}", e); + match e { + ConnectionError::Io(_) => { + sleep(reconnect_interval).await; + } + ConnectionError::MqttState(StateError::AwaitPingResp) => {} + _ => return Err(e.into()), + } + } } } }); - let publisher = self.publisher.clone(); - let incoming_task: JoinHandle> = - task::spawn(async move { - loop { - if let Incoming::Publish(publish) = incoming_rx.recv().await.map_err(|_| { - SpawnError::Internal("Incoming event channel sender closed.") - })? { + let state = self.state.clone(); + let incoming_task: JoinHandle> = task::spawn(async move { + let publisher = state.lock().await.publisher.clone(); + let device_base = format!("{}/", publisher.device_base); + loop { + match incoming_rx.recv().await { + Err(_) => { + // This happens when the task above exits either because of an error or + // because we disconnected intentionally. + log::trace!("Incoming event channel sender closed."); + return Ok(()); + } + Ok(Incoming::Publish(publish)) => { if let Some(rest) = publish.topic.strip_prefix(&device_base) { if let ([node_id, property_id, "set"], Ok(payload)) = ( rest.split('/').collect::>().as_slice(), @@ -332,8 +397,14 @@ impl HomieDevice { log::warn!("Unexpected publish: {:?}", publish); } } + Ok(Incoming::ConnAck(_)) => { + // TODO: Only if this is not the initial connection. + state.lock().await.republish_all().await?; + } + _ => {} } - }); + } + }); try_join_unit_handles(mqtt_task, incoming_task) } @@ -342,135 +413,46 @@ impl HomieDevice { /// This will panic if you attempt to add a node with the same ID as a node which was previously /// added. pub async fn add_node(&mut self, node: Node) -> Result<(), ClientError> { - // First check that there isn't already a node with the same ID. - if self.nodes.iter().any(|n| n.id == node.id) { - panic!("Tried to add node with duplicate ID: {:?}", node); - } - self.nodes.push(node); - // `node` was moved into the `nodes` vector, but we can safely get a reference to it because - // nothing else can modify `nodes` in the meantime. - let node = &self.nodes[self.nodes.len() - 1]; - - self.publish_node(&node).await?; - self.publish_nodes().await + self.state.lock().await.add_node(node).await } /// Remove the node with the given ID. pub async fn remove_node(&mut self, node_id: &str) -> Result<(), ClientError> { - // Panic on attempt to remove a node which was never added. - let index = self.nodes.iter().position(|n| n.id == node_id).unwrap(); - self.unpublish_node(&self.nodes[index]).await?; - self.nodes.remove(index); - self.publish_nodes().await - } - - async fn publish_node(&self, node: &Node) -> Result<(), ClientError> { - self.publisher - .publish_retained(&format!("{}/$name", node.id), node.name.as_str()) - .await?; - self.publisher - .publish_retained(&format!("{}/$type", node.id), node.node_type.as_str()) - .await?; - let mut property_ids: Vec<&str> = vec![]; - for property in &node.properties { - property_ids.push(&property.id); - self.publisher - .publish_retained( - &format!("{}/{}/$name", node.id, property.id), - property.name.as_str(), - ) - .await?; - self.publisher - .publish_retained( - &format!("{}/{}/$datatype", node.id, property.id), - property.datatype, - ) - .await?; - self.publisher - .publish_retained( - &format!("{}/{}/$settable", node.id, property.id), - if property.settable { "true" } else { "false" }, - ) - .await?; - if let Some(unit) = &property.unit { - self.publisher - .publish_retained(&format!("{}/{}/$unit", node.id, property.id), unit.as_str()) - .await?; - } - if let Some(format) = &property.format { - self.publisher - .publish_retained( - &format!("{}/{}/$format", node.id, property.id), - format.as_str(), - ) - .await?; - } - if property.settable { - self.publisher - .subscribe(&format!("{}/{}/set", node.id, property.id)) - .await?; - } - } - self.publisher - .publish_retained(&format!("{}/$properties", node.id), property_ids.join(",")) - .await?; - Ok(()) - } - - async fn unpublish_node(&self, node: &Node) -> Result<(), ClientError> { - for property in &node.properties { - if property.settable { - self.publisher - .unsubscribe(&format!("{}/{}/set", node.id, property.id)) - .await?; - } - } - Ok(()) - } - - async fn publish_nodes(&mut self) -> Result<(), ClientError> { - let node_ids = self - .nodes - .iter() - .map(|node| node.id.as_str()) - .collect::>() - .join(","); - self.publisher.publish_retained("$nodes", node_ids).await - } - - async fn set_state(&mut self, state: State) -> Result<(), ClientError> { - self.state = state; - self.publisher.publish_retained("$state", self.state).await + self.state.lock().await.remove_node(node_id).await } /// Update the [state](https://homieiot.github.io/specification/#device-lifecycle) of the Homie /// device to 'ready'. This should be called once it is ready to begin normal operation, or to /// return to normal operation after calling `sleep()` or `alert()`. pub async fn ready(&mut self) -> Result<(), ClientError> { - assert!(&[State::Init, State::Sleeping, State::Alert].contains(&self.state)); - self.set_state(State::Ready).await + let mut state = self.state.lock().await; + assert!(&[State::Init, State::Sleeping, State::Alert].contains(&state.state)); + state.set_state(State::Ready).await } /// Update the [state](https://homieiot.github.io/specification/#device-lifecycle) of the Homie /// device to 'sleeping'. This should be only be called after `ready()`, otherwise it will panic. pub async fn sleep(&mut self) -> Result<(), ClientError> { - assert_eq!(self.state, State::Ready); - self.set_state(State::Sleeping).await + let mut state = self.state.lock().await; + assert_eq!(state.state, State::Ready); + state.set_state(State::Sleeping).await } /// Update the [state](https://homieiot.github.io/specification/#device-lifecycle) of the Homie /// device to 'alert', to indicate that something wrong is happening and manual intervention may /// be required. This should be only be called after `ready()`, otherwise it will panic. pub async fn alert(&mut self) -> Result<(), ClientError> { - assert_eq!(self.state, State::Ready); - self.set_state(State::Alert).await + let mut state = self.state.lock().await; + assert_eq!(state.state, State::Ready); + state.set_state(State::Alert).await } /// Disconnect cleanly from the MQTT broker, after updating the state of the Homie device to // 'disconnected'. - pub async fn disconnect(mut self) -> Result<(), ClientError> { - self.set_state(State::Disconnected).await?; - self.publisher.client.disconnect().await + pub async fn disconnect(self) -> Result<(), ClientError> { + let mut state = self.state.lock().await; + state.set_state(State::Disconnected).await?; + state.publisher.client.disconnect().await } /// Publish a new value for the given property of the given node of this device. The caller is @@ -481,12 +463,96 @@ impl HomieDevice { property_id: &str, value: impl ToString, ) -> Result<(), ClientError> { - self.publisher + // TODO: If we are disconnected, just keep track of the latest value for each property to + // publish after reconnecting, rather than queuing these all up. + self.state + .lock() + .await + .publisher .publish_retained(&format!("{}/{}", node_id, property_id), value.to_string()) .await } } +/// The internal state of a `HomieDevice`, so it can be shared between threads. +#[derive(Debug)] +struct DeviceState { + publisher: DevicePublisher, + device_name: String, + nodes: Vec, + state: State, + extension_ids: String, +} + +impl DeviceState { + /// Publish the current list of node IDs. + async fn publish_nodes(&self) -> Result<(), ClientError> { + let node_ids = self + .nodes + .iter() + .map(|node| node.id.as_str()) + .collect::>() + .join(","); + self.publisher.publish_retained("$nodes", node_ids).await + } + + /// Set the state to the given value, and publish it. + async fn set_state(&mut self, state: State) -> Result<(), ClientError> { + self.state = state; + self.send_state().await + } + + /// Publish the current state. + async fn send_state(&self) -> Result<(), ClientError> { + self.publisher.publish_retained("$state", self.state).await + } + + async fn republish_all(&self) -> Result<(), ClientError> { + for node in &self.nodes { + self.publisher.publish_node(node).await?; + } + self.publish_nodes().await?; + // TODO: Stats and firmware extensions + self.publisher + .publish_retained("$homie", HOMIE_VERSION) + .await?; + self.publisher + .publish_retained("$extensions", self.extension_ids.as_str()) + .await?; + self.publisher + .publish_retained("$implementation", HOMIE_IMPLEMENTATION) + .await?; + self.publisher + .publish_retained("$name", self.device_name.as_str()) + .await?; + self.send_state().await + } + + /// Add a node to the Homie device and publish it. + async fn add_node(&mut self, node: Node) -> Result<(), ClientError> { + // First check that there isn't already a node with the same ID. + if self.nodes.iter().any(|n| n.id == node.id) { + panic!("Tried to add node with duplicate ID: {:?}", node); + } + self.nodes.push(node); + // `node` was moved into the `nodes` vector, but we can safely get a reference to it because + // nothing else can modify `nodes` in the meantime. + let node = &self.nodes[self.nodes.len() - 1]; + + self.publisher.publish_node(&node).await?; + self.publish_nodes().await + } + + /// Remove the node with the given ID. + async fn remove_node(&mut self, node_id: &str) -> Result<(), ClientError> { + // Panic on attempt to remove a node which was never added. + let index = self.nodes.iter().position(|n| n.id == node_id).unwrap(); + self.publisher.unpublish_node(&self.nodes[index]).await?; + self.nodes.remove(index); + self.publish_nodes().await + } +} + #[derive(Clone, Debug)] struct DevicePublisher { pub client: AsyncClient, @@ -521,6 +587,61 @@ impl DevicePublisher { let topic = format!("{}/{}", self.device_base, subtopic); self.client.unsubscribe(topic).await } + + /// Publish metadata about the given node and its properties. + async fn publish_node(&self, node: &Node) -> Result<(), ClientError> { + self.publish_retained(&format!("{}/$name", node.id), node.name.as_str()) + .await?; + self.publish_retained(&format!("{}/$type", node.id), node.node_type.as_str()) + .await?; + let mut property_ids: Vec<&str> = vec![]; + for property in &node.properties { + property_ids.push(&property.id); + self.publish_retained( + &format!("{}/{}/$name", node.id, property.id), + property.name.as_str(), + ) + .await?; + self.publish_retained( + &format!("{}/{}/$datatype", node.id, property.id), + property.datatype, + ) + .await?; + self.publish_retained( + &format!("{}/{}/$settable", node.id, property.id), + if property.settable { "true" } else { "false" }, + ) + .await?; + if let Some(unit) = &property.unit { + self.publish_retained(&format!("{}/{}/$unit", node.id, property.id), unit.as_str()) + .await?; + } + if let Some(format) = &property.format { + self.publish_retained( + &format!("{}/{}/$format", node.id, property.id), + format.as_str(), + ) + .await?; + } + if property.settable { + self.subscribe(&format!("{}/{}/set", node.id, property.id)) + .await?; + } + } + self.publish_retained(&format!("{}/$properties", node.id), property_ids.join(",")) + .await?; + Ok(()) + } + + async fn unpublish_node(&self, node: &Node) -> Result<(), ClientError> { + for property in &node.properties { + if property.settable { + self.unsubscribe(&format!("{}/{}/set", node.id, property.id)) + .await?; + } + } + Ok(()) + } } /// Legacy stats extension. @@ -552,6 +673,7 @@ impl HomieStats { fn spawn(self) -> impl Future> { let task: JoinHandle> = task::spawn(async move { loop { + // TODO: Break out of this loop if disconnection is requested. let uptime = Instant::now() - self.start_time; self.publisher .publish_retained("$stats/uptime", uptime.as_secs().to_string()) @@ -636,7 +758,12 @@ mod tests { let (cancel_tx, _cancel_rx) = async_channel::unbounded(); let client = AsyncClient::from_senders(requests_tx, cancel_tx); let publisher = DevicePublisher::new(client, "homie/test-device".to_string()); - let device = HomieDevice::new(publisher, "Test device".to_string(), &[]); + let device = HomieDevice::new( + publisher, + "Test device".to_string(), + &[], + DEFAULT_RECONNECT_INTERVAL, + ); (device, requests_rx) } @@ -744,8 +871,9 @@ mod tests { let (_event_loop, homie, _stats, firmware, _callback) = builder.build(); - assert_eq!(homie.device_name, "Test device"); - assert_eq!(homie.publisher.device_base, "homie/test-device"); + let state = homie.state.lock().await; + assert_eq!(state.device_name, "Test device"); + assert_eq!(state.publisher.device_base, "homie/test-device"); assert!(firmware.is_none()); Ok(()) @@ -763,8 +891,9 @@ mod tests { let (_event_loop, homie, _stats, firmware, _callback) = builder.build(); - assert_eq!(homie.device_name, "Test device"); - assert_eq!(homie.publisher.device_base, "homie/test-device"); + let state = homie.state.lock().await; + assert_eq!(state.device_name, "Test device"); + assert_eq!(state.publisher.device_base, "homie/test-device"); let firmware = firmware.unwrap(); assert_eq!(firmware.firmware_name, "firmware_name"); assert_eq!(firmware.firmware_version, "firmware_version"); diff --git a/mijia-homie/src/config.rs b/mijia-homie/src/config.rs index 69cd39a5..bec0328c 100644 --- a/mijia-homie/src/config.rs +++ b/mijia-homie/src/config.rs @@ -1,4 +1,5 @@ use eyre::Report; +use homie_device::DEFAULT_RECONNECT_INTERVAL; use mijia::bluetooth::{MacAddress, ParseMacAddressError}; use rumqttc::{MqttOptions, Transport}; use rustls::ClientConfig; @@ -45,6 +46,11 @@ pub struct MqttConfig { pub username: Option, pub password: Option, pub client_name: Option, + #[serde( + deserialize_with = "de_duration_seconds", + rename = "reconnect_interval_seconds" + )] + pub reconnect_interval: Duration, } impl Default for MqttConfig { @@ -56,6 +62,7 @@ impl Default for MqttConfig { username: None, password: None, client_name: None, + reconnect_interval: DEFAULT_RECONNECT_INTERVAL, } } } diff --git a/mijia-homie/src/main.rs b/mijia-homie/src/main.rs index 43abf4cc..34f515c9 100644 --- a/mijia-homie/src/main.rs +++ b/mijia-homie/src/main.rs @@ -37,11 +37,13 @@ async fn main() -> Result<(), eyre::Report> { let config = Config::from_file()?; let sensor_names = read_sensor_names(&config.homie.sensor_names_filename)?; + let reconnect_interval = config.mqtt.reconnect_interval; let mqtt_options = get_mqtt_options(config.mqtt, &config.homie.device_id); let device_base = format!("{}/{}", config.homie.prefix, config.homie.device_id); let mut homie_builder = HomieDevice::builder(&device_base, &config.homie.device_name, mqtt_options); homie_builder.set_firmware(env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION")); + homie_builder.set_reconnect_interval(reconnect_interval); let (homie, homie_handle) = homie_builder.spawn().await?; // Connect a Bluetooth session.