diff --git a/Cargo.lock b/Cargo.lock index 87a778b..1606e6b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -980,7 +980,7 @@ dependencies = [ [[package]] name = "dkn-compute" -version = "0.2.8" +version = "0.2.9" dependencies = [ "async-trait", "base64 0.22.1", @@ -998,6 +998,7 @@ dependencies = [ "ollama-workflows", "openssl", "parking_lot", + "port_check", "rand 0.8.5", "reqwest 0.12.8", "serde", @@ -3760,6 +3761,12 @@ dependencies = [ "universal-hash", ] +[[package]] +name = "port_check" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2110609fb863cdb367d4e69d6c43c81ba6a8c7d18e80082fe9f3ef16b23afeed" + [[package]] name = "portable-atomic" version = "1.9.0" diff --git a/Cargo.toml b/Cargo.toml index 7d468b2..10cd32a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "dkn-compute" -version = "0.2.8" +version = "0.2.9" edition = "2021" license = "Apache-2.0" readme = "README.md" @@ -70,6 +70,7 @@ libp2p = { git = "https://github.com/anilaltuner/rust-libp2p.git", rev = "7ce9f9 libp2p-identity = { version = "0.2.9", features = ["secp256k1"] } tracing = { version = "0.1.40" } tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } +port_check = "0.2.1" # Vendor OpenSSL so that its easier to build cross-platform packages [dependencies.openssl] diff --git a/src/config/mod.rs b/src/config/mod.rs index 3fa8e89..7cc3bd2 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -2,15 +2,16 @@ mod models; mod ollama; mod openai; -use crate::utils::crypto::to_address; +use crate::utils::{address_in_use, crypto::to_address}; use eyre::{eyre, Result}; +use libp2p::Multiaddr; use libsecp256k1::{PublicKey, SecretKey}; use models::ModelConfig; use ollama::OllamaConfig; use ollama_workflows::ModelProvider; use openai::OpenAIConfig; -use std::{env, time::Duration}; +use std::{env, str::FromStr, time::Duration}; /// Timeout duration for checking model performance during a generation. const CHECK_TIMEOUT_DURATION: Duration = Duration::from_secs(80); @@ -28,8 +29,8 @@ pub struct DriaComputeNodeConfig { pub address: [u8; 20], /// Admin public key, used for message authenticity. pub admin_public_key: PublicKey, - /// P2P listen address as a string, e.g. `/ip4/0.0.0.0/tcp/4001`. - pub p2p_listen_addr: String, + /// P2P listen address, e.g. `/ip4/0.0.0.0/tcp/4001`. + pub p2p_listen_addr: Multiaddr, /// Available LLM models & providers for the node. pub model_config: ModelConfig, /// Even if Ollama is not used, we store the host & port here. @@ -104,9 +105,11 @@ impl DriaComputeNodeConfig { } log::info!("Models: {:?}", model_config.models); - let p2p_listen_addr = env::var("DKN_P2P_LISTEN_ADDR") + let p2p_listen_addr_str = env::var("DKN_P2P_LISTEN_ADDR") .map(|addr| addr.trim_matches('"').to_string()) .unwrap_or(DEFAULT_P2P_LISTEN_ADDR.to_string()); + let p2p_listen_addr = Multiaddr::from_str(&p2p_listen_addr_str) + .expect("Could not parse the given P2P listen address."); Self { admin_public_key, @@ -178,6 +181,18 @@ impl DriaComputeNodeConfig { Ok(()) } } + + // ensure that listen address is free + pub fn check_address_in_use(&self) -> Result<()> { + if address_in_use(&self.p2p_listen_addr) { + return Err(eyre!( + "Listen address {} is already in use.", + self.p2p_listen_addr + )); + } + + Ok(()) + } } #[cfg(test)] diff --git a/src/main.rs b/src/main.rs index 368e535..a2fd86d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -46,6 +46,7 @@ async fn main() -> Result<()> { // create configurations & check required services let config = DriaComputeNodeConfig::new(); + config.check_address_in_use()?; let service_check_token = token.clone(); let mut config_clone = config.clone(); let service_check_handle = tokio::spawn(async move { diff --git a/src/node.rs b/src/node.rs index db8bc3a..e824f1c 100644 --- a/src/node.rs +++ b/src/node.rs @@ -1,6 +1,6 @@ use eyre::{eyre, Result}; -use libp2p::{gossipsub, Multiaddr}; -use std::{str::FromStr, time::Duration}; +use libp2p::gossipsub; +use std::time::Duration; use tokio_util::sync::CancellationToken; use crate::{ @@ -40,7 +40,6 @@ impl DriaComputeNode { cancellation: CancellationToken, ) -> Result { let keypair = secret_to_keypair(&config.secret_key); - let listen_addr = Multiaddr::from_str(config.p2p_listen_addr.as_str())?; // get available nodes (bootstrap, relay, rpc) for p2p let available_nodes = AvailableNodes::default() @@ -53,7 +52,7 @@ impl DriaComputeNode { ) .sort_dedup(); - let p2p = P2PClient::new(keypair, listen_addr, &available_nodes)?; + let p2p = P2PClient::new(keypair, config.p2p_listen_addr.clone(), &available_nodes)?; Ok(DriaComputeNode { p2p, diff --git a/src/p2p/client.rs b/src/p2p/client.rs index 06c9441..e389206 100644 --- a/src/p2p/client.rs +++ b/src/p2p/client.rs @@ -236,11 +236,7 @@ impl P2PClient { /// /// - For Kademlia, we check the kademlia protocol and then add the address to the Kademlia routing table. fn handle_identify_event(&mut self, peer_id: PeerId, info: identify::Info) { - // we only care about the observed address, although there may be other addresses at `info.listen_addrs` - // TODO: this may be wrong - let addr = info.observed_addr; - - // check protocol string + // check identify protocol string if info.protocol_version != P2P_PROTOCOL_STRING { log::warn!( "Identify: Peer {} has different Identify protocol: (them {}, you {})", @@ -259,17 +255,31 @@ impl P2PClient { { // if it matches our protocol, add it to the Kademlia routing table if *kad_protocol == P2P_KADEMLIA_PROTOCOL { - log::info!( - "Identify: {} peer {} identified at {}", - kad_protocol, - peer_id, - addr - ); - - self.swarm - .behaviour_mut() - .kademlia - .add_address(&peer_id, addr); + // filter listen addresses + let addrs = info.listen_addrs.into_iter().filter(|listen_addr| { + if let Some(Protocol::Ip4(ipv4_addr)) = listen_addr.iter().next() { + // ignore private & localhost addresses + !(ipv4_addr.is_private() || ipv4_addr.is_loopback()) + } else { + // ignore non ipv4 addresses + false + } + }); + + // add them to kademlia + for addr in addrs { + log::info!( + "Identify: {} peer {} identified at {}", + kad_protocol, + peer_id, + addr + ); + + self.swarm + .behaviour_mut() + .kademlia + .add_address(&peer_id, addr); + } } else { log::warn!( "Identify: Peer {} has different Kademlia version: (them {}, you {})", diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 1db2ca6..23606ae 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -7,7 +7,12 @@ pub use message::DKNMessage; mod available_nodes; pub use available_nodes::AvailableNodes; -use std::time::{Duration, SystemTime}; +use libp2p::{multiaddr::Protocol, Multiaddr}; +use port_check::is_port_reachable; +use std::{ + net::{Ipv4Addr, SocketAddrV4}, + time::{Duration, SystemTime}, +}; /// Returns the current time in nanoseconds since the Unix epoch. /// @@ -23,6 +28,34 @@ pub fn get_current_time_nanos() -> u128 { .as_nanos() } +/// Checks if a given address is already in use locally. +/// This is mostly used to see if the P2P address is already in use. +/// +/// Simply tries to connect with TCP to the given address. +#[inline] +pub fn address_in_use(addr: &Multiaddr) -> bool { + addr.iter() + // find the port within our multiaddr + .find_map(|p| { + if let Protocol::Tcp(port) = p { + Some(port) + } else { + None + } + + // } + }) + // check if its reachable or not + .map(|port| is_port_reachable(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port))) + .unwrap_or_else(|| { + log::error!( + "Could not find any TCP port in the given address: {:?}", + addr + ); + false + }) +} + /// Utility to parse comma-separated string values, mostly read from the environment. /// - Trims `"` from both ends at the start /// - For each item, trims whitespace from both ends