Skip to content

Commit

Permalink
feat: support to server signed ca on tls
Browse files Browse the repository at this point in the history
  • Loading branch information
vitorfdl committed Jun 13, 2024
1 parent 0f028a4 commit aacc6e4
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 23 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ rand = "0.8.5"
regex = "1.10.4"
reqwest = { version = "0.12.4", features = ["rustls-tls", "json"] }
rumqttc = { version = "0.24.0" }
rustls-native-certs = "0.7.0"
serde = "1.0.203"
serde_json = "1.0.117"
tokio = {version = "1.37.0", features = ["full", "rt-multi-thread"]}
Expand Down
27 changes: 14 additions & 13 deletions src/schema/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl RelayConfig {

Ok(RelayConfig {
id,
config: config.with_defaults(),
config: config.with_defaults()?,
profile_id,
state,
})
Expand All @@ -51,33 +51,34 @@ impl RelayConfig {
}

impl ConfigFile {
pub fn with_defaults(mut self) -> Self {
pub fn with_defaults(mut self) -> anyhow::Result<Self> {
if self.tagoio_url.is_none() {
self.tagoio_url = Some("https://api.tago.io".to_string());
}
if self.downlink_port.is_none() {
self.downlink_port = Some("3000".to_string());
}
self.mqtt = self.mqtt.with_defaults();
self
self.mqtt = self.mqtt.with_defaults()?;
Ok(self)
}
}

impl MQTT {
pub fn with_defaults(mut self) -> Self {
pub fn with_defaults(mut self) -> anyhow::Result<Self> {
if self.client_id.is_none() {
self.client_id = Some("tagoio-relay".to_string());
}

if self.address != "localhost" && !self.is_valid_address(&self.address) {
panic!("Invalid MQTT address: {}", self.address);
return Err(anyhow::anyhow!("Invalid MQTT address: {}", self.address));
}

self
Ok(self)
}

fn is_valid_address(&self, address: &str) -> bool {
let re = Regex::new(r"^(?:(?:[a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}|(?:\d{1,3}\.){3}\d{1,3}|localhost)$").unwrap();
let re = Regex::new(
r"^(?:(?:ws|wss|mqtt|mqtts)://)?(?:(?:[a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}|(?:\d{1,3}\.){3}\d{1,3}|localhost)$",
)
.unwrap();
re.is_match(address)
}
}
Expand Down Expand Up @@ -163,7 +164,7 @@ mod tests {
},
};

let config_with_defaults = config.with_defaults();
let config_with_defaults = config.with_defaults().unwrap();

assert_eq!(config_with_defaults.tagoio_url.unwrap(), "https://api.tago.io");
assert_eq!(config_with_defaults.downlink_port.unwrap(), "3000");
Expand All @@ -185,7 +186,7 @@ mod tests {
broker_tls_key: None,
};

let mqtt_with_defaults = mqtt.with_defaults();
let mqtt_with_defaults = mqtt.with_defaults().unwrap();

assert_eq!(mqtt_with_defaults.client_id.unwrap(), "tagoio-relay");
}
Expand All @@ -205,7 +206,7 @@ mod tests {
broker_tls_cert: None,
broker_tls_key: None,
};
mqtt.with_defaults();
mqtt.with_defaults().unwrap();
}

// #[test]
Expand Down
27 changes: 19 additions & 8 deletions src/services/mqttrelay.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use crate::{schema::RelayConfig, utils::calculate_backoff};
use rumqttc::{AsyncClient, MqttOptions, QoS, TlsConfiguration};
use rumqttc::{
tokio_rustls::rustls::{ClientConfig, RootCertStore},
AsyncClient, MqttOptions, QoS, TlsConfiguration,
};
use serde::Deserialize;
use std::sync::Arc;
use tokio::{
sync::{mpsc, Mutex},
time::{sleep, Duration},
};

const BACKOFF_MAX_RETRIES: u32 = 20;
#[derive(Deserialize)]
pub struct PublishMessage {
Expand Down Expand Up @@ -95,16 +97,25 @@ fn initialize_mqtt_options(relay_cfg: &RelayConfig) -> MqttOptions {
alpn: None,
client_auth: Some(client_auth),
}));
} else {
// Use rustls-native-certs to load root certificates from the operating system.
let mut root_cert_store = RootCertStore::empty();
root_cert_store
.add_parsable_certificates(rustls_native_certs::load_native_certs().expect("could not load platform certs"));

let client_config = ClientConfig::builder()
.with_root_certificates(Arc::new(root_cert_store))
.with_no_client_auth();

mqttoptions.set_transport(rumqttc::Transport::tls_with_config(client_config.into()));
}
}

if let Some(username) = username {
if ca_file.is_some() {
mqttoptions.set_credentials(
username,
password.as_ref().expect("Password must be provided if username is set"),
);
}
mqttoptions.set_credentials(
username,
password.as_ref().expect("Password must be provided if username is set"),
);
}

mqttoptions
Expand Down
5 changes: 3 additions & 2 deletions src/services/tagoio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ use crate::{schema::RelayConfig, CONFIG_FILE};
pub async fn get_relay_list() -> Result<Vec<Arc<RelayConfig>>, Error> {
let config_file = CONFIG_FILE.read().unwrap();
if let Some(config) = &*config_file {
log::info!(target: "info", "Config file loaded successfully");
let relay = RelayConfig::new_with_defaults(None, config.clone()).unwrap();
let relay = RelayConfig::new_with_defaults(None, config.clone())?;
let relays: Vec<Arc<RelayConfig>> = vec![Arc::new(relay)];

log::info!(target: "info", "Config file loaded successfully");

return Ok(relays);
}

Expand Down
9 changes: 9 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ struct ConfigFileResponse {

const DEFAULT_CONFIG: &str = include_str!("./default_config.toml");

/**
* Get the path to the configuration file
*/
fn get_config_path(user_path: Option<String>) -> std::path::PathBuf {
let env_config_path = if user_path.is_none() {
std::env::var("CONFIG_PATH").ok()
Expand All @@ -30,6 +33,9 @@ fn get_config_path(user_path: Option<String>) -> std::path::PathBuf {
config_path
}

/**
* Initialize the configuration file
*/
pub fn init_config(user_path: Option<impl AsRef<str>>) {
let config_path = get_config_path(user_path.map(|s| s.as_ref().to_string()));
if config_path.exists() {
Expand All @@ -42,6 +48,9 @@ pub fn init_config(user_path: Option<impl AsRef<str>>) {
println!("Configuration file created at {}", config_path.display());
}

/**
* Fetch the configuration file
*/
pub fn fetch_config_file(user_path: Option<String>) -> Option<ConfigFile> {
let config_path = get_config_path(user_path);
// If the config file doesn't exist, create it
Expand Down

0 comments on commit aacc6e4

Please sign in to comment.