diff --git a/Cargo.lock b/Cargo.lock index 7a19275..f9dbaff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2004,6 +2004,7 @@ dependencies = [ "regex", "reqwest", "rumqttc", + "rustls-native-certs", "serde", "serde_json", "tokio", diff --git a/Cargo.toml b/Cargo.toml index 550d095..5e3541a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ rand = "0.8.5" regex = "1.10.4" reqwest = { version = "0.12.4", features = ["rustls-tls", "json"] } rumqttc = { version = "0.24.0" } +rustls-native-certs = "0.7.0" serde = "1.0.203" serde_json = "1.0.117" tokio = {version = "1.37.0", features = ["full", "rt-multi-thread"]} diff --git a/src/schema/mod.rs b/src/schema/mod.rs index 3908925..3f06b74 100644 --- a/src/schema/mod.rs +++ b/src/schema/mod.rs @@ -39,7 +39,7 @@ impl RelayConfig { Ok(RelayConfig { id, - config: config.with_defaults(), + config: config.with_defaults()?, profile_id, state, }) @@ -51,33 +51,34 @@ impl RelayConfig { } impl ConfigFile { - pub fn with_defaults(mut self) -> Self { + pub fn with_defaults(mut self) -> anyhow::Result { if self.tagoio_url.is_none() { self.tagoio_url = Some("https://api.tago.io".to_string()); } if self.downlink_port.is_none() { self.downlink_port = Some("3000".to_string()); } - self.mqtt = self.mqtt.with_defaults(); - self + self.mqtt = self.mqtt.with_defaults()?; + Ok(self) } } impl MQTT { - pub fn with_defaults(mut self) -> Self { + pub fn with_defaults(mut self) -> anyhow::Result { if self.client_id.is_none() { self.client_id = Some("tagoio-relay".to_string()); } - if self.address != "localhost" && !self.is_valid_address(&self.address) { - panic!("Invalid MQTT address: {}", self.address); + return Err(anyhow::anyhow!("Invalid MQTT address: {}", self.address)); } - - self + Ok(self) } fn is_valid_address(&self, address: &str) -> bool { - let re = Regex::new(r"^(?:(?:[a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}|(?:\d{1,3}\.){3}\d{1,3}|localhost)$").unwrap(); + let re = Regex::new( + r"^(?:(?:ws|wss|mqtt|mqtts)://)?(?:(?:[a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}|(?:\d{1,3}\.){3}\d{1,3}|localhost)$", + ) + .unwrap(); re.is_match(address) } } @@ -163,7 +164,7 @@ mod tests { }, }; - let config_with_defaults = config.with_defaults(); + let config_with_defaults = config.with_defaults().unwrap(); assert_eq!(config_with_defaults.tagoio_url.unwrap(), "https://api.tago.io"); assert_eq!(config_with_defaults.downlink_port.unwrap(), "3000"); @@ -185,7 +186,7 @@ mod tests { broker_tls_key: None, }; - let mqtt_with_defaults = mqtt.with_defaults(); + let mqtt_with_defaults = mqtt.with_defaults().unwrap(); assert_eq!(mqtt_with_defaults.client_id.unwrap(), "tagoio-relay"); } @@ -205,7 +206,7 @@ mod tests { broker_tls_cert: None, broker_tls_key: None, }; - mqtt.with_defaults(); + mqtt.with_defaults().unwrap(); } // #[test] diff --git a/src/services/mqttrelay.rs b/src/services/mqttrelay.rs index a46a844..9ca7a07 100644 --- a/src/services/mqttrelay.rs +++ b/src/services/mqttrelay.rs @@ -1,12 +1,14 @@ use crate::{schema::RelayConfig, utils::calculate_backoff}; -use rumqttc::{AsyncClient, MqttOptions, QoS, TlsConfiguration}; +use rumqttc::{ + tokio_rustls::rustls::{ClientConfig, RootCertStore}, + AsyncClient, MqttOptions, QoS, TlsConfiguration, +}; use serde::Deserialize; use std::sync::Arc; use tokio::{ sync::{mpsc, Mutex}, time::{sleep, Duration}, }; - const BACKOFF_MAX_RETRIES: u32 = 20; #[derive(Deserialize)] pub struct PublishMessage { @@ -95,16 +97,25 @@ fn initialize_mqtt_options(relay_cfg: &RelayConfig) -> MqttOptions { alpn: None, client_auth: Some(client_auth), })); + } else { + // Use rustls-native-certs to load root certificates from the operating system. + let mut root_cert_store = RootCertStore::empty(); + root_cert_store + .add_parsable_certificates(rustls_native_certs::load_native_certs().expect("could not load platform certs")); + + let client_config = ClientConfig::builder() + .with_root_certificates(Arc::new(root_cert_store)) + .with_no_client_auth(); + + mqttoptions.set_transport(rumqttc::Transport::tls_with_config(client_config.into())); } } if let Some(username) = username { - if ca_file.is_some() { - mqttoptions.set_credentials( - username, - password.as_ref().expect("Password must be provided if username is set"), - ); - } + mqttoptions.set_credentials( + username, + password.as_ref().expect("Password must be provided if username is set"), + ); } mqttoptions diff --git a/src/services/tagoio.rs b/src/services/tagoio.rs index 5a30723..b176c1c 100644 --- a/src/services/tagoio.rs +++ b/src/services/tagoio.rs @@ -16,10 +16,11 @@ use crate::{schema::RelayConfig, CONFIG_FILE}; pub async fn get_relay_list() -> Result>, Error> { let config_file = CONFIG_FILE.read().unwrap(); if let Some(config) = &*config_file { - log::info!(target: "info", "Config file loaded successfully"); - let relay = RelayConfig::new_with_defaults(None, config.clone()).unwrap(); + let relay = RelayConfig::new_with_defaults(None, config.clone())?; let relays: Vec> = vec![Arc::new(relay)]; + log::info!(target: "info", "Config file loaded successfully"); + return Ok(relays); } diff --git a/src/utils.rs b/src/utils.rs index c77a781..0bee6aa 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -11,6 +11,9 @@ struct ConfigFileResponse { const DEFAULT_CONFIG: &str = include_str!("./default_config.toml"); +/** + * Get the path to the configuration file + */ fn get_config_path(user_path: Option) -> std::path::PathBuf { let env_config_path = if user_path.is_none() { std::env::var("CONFIG_PATH").ok() @@ -30,6 +33,9 @@ fn get_config_path(user_path: Option) -> std::path::PathBuf { config_path } +/** + * Initialize the configuration file + */ pub fn init_config(user_path: Option>) { let config_path = get_config_path(user_path.map(|s| s.as_ref().to_string())); if config_path.exists() { @@ -42,6 +48,9 @@ pub fn init_config(user_path: Option>) { println!("Configuration file created at {}", config_path.display()); } +/** + * Fetch the configuration file + */ pub fn fetch_config_file(user_path: Option) -> Option { let config_path = get_config_path(user_path); // If the config file doesn't exist, create it