diff --git a/Cargo.lock b/Cargo.lock index c399b9b8776..c7a3a4ad73b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7470,6 +7470,8 @@ dependencies = [ "dashmap", "etherparse", "futures", + "ip_network", + "ip_network_table", "log", "nym-task", "tap", diff --git a/common/wireguard/Cargo.toml b/common/wireguard/Cargo.toml index b20fd260361..a78920f7f58 100644 --- a/common/wireguard/Cargo.toml +++ b/common/wireguard/Cargo.toml @@ -22,6 +22,8 @@ bytes = "1.5.0" dashmap = "5.5.3" etherparse = "0.13.0" futures = "0.3.28" +ip_network = "0.4.1" +ip_network_table = "0.2.0" log.workspace = true nym-task = { path = "../task" } tap.workspace = true diff --git a/common/wireguard/src/lib.rs b/common/wireguard/src/lib.rs index 15383eb335c..a21d7f6035a 100644 --- a/common/wireguard/src/lib.rs +++ b/common/wireguard/src/lib.rs @@ -1,40 +1,48 @@ #![cfg_attr(not(target_os = "linux"), allow(dead_code))] -use nym_task::TaskClient; - mod error; mod event; +mod network_table; mod platform; mod setup; -mod tun; mod udp_listener; +mod wg_tunnel; // Currently the module related to setting up the virtual network device is platform specific. #[cfg(target_os = "linux")] use platform::linux::tun_device; -type ActivePeers = - dashmap::DashMap>; +#[derive(Clone)] +struct TunTaskTx(tokio::sync::mpsc::UnboundedSender>); + +impl TunTaskTx { + fn send(&self, packet: Vec) -> Result<(), tokio::sync::mpsc::error::SendError>> { + self.0.send(packet) + } +} #[cfg(target_os = "linux")] pub async fn start_wireguard( - task_client: TaskClient, + task_client: nym_task::TaskClient, ) -> Result<(), Box> { + use std::sync::Arc; + // The set of active tunnels indexed by the peer's address - let active_peers = std::sync::Arc::new(ActivePeers::new()); + let active_peers = Arc::new(udp_listener::ActivePeers::new()); + let peers_by_ip = Arc::new(std::sync::Mutex::new(network_table::NetworkTable::new())); // Start the tun device that is used to relay traffic outbound - let tun_task_tx = tun_device::start_tun_device(active_peers.clone()); + let tun_task_tx = tun_device::start_tun_device(peers_by_ip.clone()); // Start the UDP listener that clients connect to - udp_listener::start_udp_listener(tun_task_tx, active_peers, task_client).await?; + udp_listener::start_udp_listener(tun_task_tx, active_peers, peers_by_ip, task_client).await?; Ok(()) } #[cfg(not(target_os = "linux"))] pub async fn start_wireguard( - _task_client: TaskClient, + _task_client: nym_task::TaskClient, ) -> Result<(), Box> { todo!("WireGuard is currently only supported on Linux") } diff --git a/common/wireguard/src/network_table.rs b/common/wireguard/src/network_table.rs new file mode 100644 index 00000000000..83008c5632d --- /dev/null +++ b/common/wireguard/src/network_table.rs @@ -0,0 +1,25 @@ +use std::net::IpAddr; + +use ip_network::IpNetwork; +use ip_network_table::IpNetworkTable; + +#[derive(Default)] +pub(crate) struct NetworkTable { + ips: IpNetworkTable, +} + +impl NetworkTable { + pub(crate) fn new() -> Self { + Self { + ips: IpNetworkTable::new(), + } + } + + pub fn insert>(&mut self, network: N, data: T) -> Option { + self.ips.insert(network, data) + } + + pub fn longest_match>(&self, ip: I) -> Option<(IpNetwork, &T)> { + self.ips.longest_match(ip) + } +} diff --git a/common/wireguard/src/platform/linux/tun_device.rs b/common/wireguard/src/platform/linux/tun_device.rs index 9d7aef79b2e..0e0b71da287 100644 --- a/common/wireguard/src/platform/linux/tun_device.rs +++ b/common/wireguard/src/platform/linux/tun_device.rs @@ -1,14 +1,17 @@ use std::{net::Ipv4Addr, sync::Arc}; use etherparse::{InternetSlice, SlicedPacket}; +use tap::TapFallible; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, - sync::mpsc::{self, UnboundedSender}, + sync::mpsc::{self}, }; use crate::{ + event::Event, setup::{TUN_BASE_NAME, TUN_DEVICE_ADDRESS, TUN_DEVICE_NETMASK}, - ActivePeers, + udp_listener::PeersByIp, + TunTaskTx, }; fn setup_tokio_tun_device(name: &str, address: Ipv4Addr, netmask: Ipv4Addr) -> tokio_tun::Tun { @@ -25,7 +28,7 @@ fn setup_tokio_tun_device(name: &str, address: Ipv4Addr, netmask: Ipv4Addr) -> t .expect("Failed to setup tun device, do you have permission?") } -pub fn start_tun_device(_active_peers: Arc) -> UnboundedSender> { +pub(crate) fn start_tun_device(peers_by_ip: Arc>) -> TunTaskTx { let tun = setup_tokio_tun_device( format!("{}%d", TUN_BASE_NAME).as_str(), TUN_DEVICE_ADDRESS.parse().unwrap(), @@ -37,6 +40,7 @@ pub fn start_tun_device(_active_peers: Arc) -> UnboundedSender>(); + let tun_task_tx = TunTaskTx(tun_task_tx); tokio::spawn(async move { let mut buf = [0u8; 1024]; @@ -55,8 +59,16 @@ pub fn start_tun_device(_active_peers: Arc) -> UnboundedSender {dst_addr}, {len} bytes)"); - // TODO: route packet to the correct peer. - log::info!("...forward packet to the correct peer (NOT YET IMPLEMENTED)"); + // Route packet to the correct peer. + if let Some(peer_tx) = peers_by_ip.lock().unwrap().longest_match(dst_addr).map(|(_, tx)| tx) { + log::info!("Forward packet to wg tunnel"); + peer_tx + .send(Event::IpPacket(packet.to_vec().into())) + .tap_err(|err| log::error!("{err}")) + .unwrap(); + } else { + log::info!("No peer found, packet dropped"); + } }, Err(err) => { log::info!("iface: read error: {err}"); diff --git a/common/wireguard/src/setup.rs b/common/wireguard/src/setup.rs index 1df9b559769..3fd948f487e 100644 --- a/common/wireguard/src/setup.rs +++ b/common/wireguard/src/setup.rs @@ -1,3 +1,5 @@ +use std::net::IpAddr; + use base64::{engine::general_purpose, Engine as _}; use boringtun::x25519; use log::info; @@ -15,35 +17,47 @@ pub const TUN_DEVICE_NETMASK: &str = "255.255.255.0"; // Corresponding public key: "WM8s8bYegwMa0TJ+xIwhk+dImk2IpDUKslDBCZPizlE=" const PRIVATE_KEY: &str = "AEqXrLFT4qjYq3wmX0456iv94uM6nDj5ugp6Jedcflg="; -// The public keys of the registered peers (clients) -const PEERS: &[&str; 1] = &[ - // Corresponding private key: "ILeN6gEh6vJ3Ju8RJ3HVswz+sPgkcKtAYTqzQRhTtlo=" - "NCIhkgiqxFx1ckKl3Zuh595DzIFl8mxju1Vg995EZhI=", - // Another key - // "mxV/mw7WZTe+0Msa0kvJHMHERDA/cSskiZWQce+TdEs=", -]; +// The public keys of the registered peer (clients) +// Corresponding private key: "ILeN6gEh6vJ3Ju8RJ3HVswz+sPgkcKtAYTqzQRhTtlo=" +const PEER: &str = "NCIhkgiqxFx1ckKl3Zuh595DzIFl8mxju1Vg995EZhI="; -pub fn init_static_dev_keys() -> (x25519::StaticSecret, x25519::PublicKey) { - // TODO: this is a temporary solution for development - let static_private_bytes: [u8; 32] = general_purpose::STANDARD - .decode(PRIVATE_KEY) +// The AllowedIPs for the connected peer, which is one a single IP and the same as the IP that the +// peer has configured on their side. +const ALLOWED_IPS: &str = "10.0.0.2"; + +fn decode_base64_key(base64_key: &str) -> [u8; 32] { + general_purpose::STANDARD + .decode(base64_key) .unwrap() .try_into() - .unwrap(); + .unwrap() +} + +pub fn server_static_private_key() -> x25519::StaticSecret { + // TODO: this is a temporary solution for development + let static_private_bytes: [u8; 32] = decode_base64_key(PRIVATE_KEY); let static_private = x25519::StaticSecret::try_from(static_private_bytes).unwrap(); let static_public = x25519::PublicKey::from(&static_private); info!( "wg public key: {}", general_purpose::STANDARD.encode(static_public) ); + static_private +} - // TODO: A single static public key is used for all peers during development - let peer_static_public_bytes: [u8; 32] = general_purpose::STANDARD - .decode(PEERS[0]) - .unwrap() - .try_into() - .unwrap(); +pub fn peer_static_public_key() -> x25519::PublicKey { + // A single static public key is used during development + let peer_static_public_bytes: [u8; 32] = decode_base64_key(PEER); let peer_static_public = x25519::PublicKey::try_from(peer_static_public_bytes).unwrap(); + info!( + "peer public key: {}", + general_purpose::STANDARD.encode(peer_static_public) + ); + peer_static_public +} - (static_private, peer_static_public) +pub fn peer_allowed_ips() -> ip_network::IpNetwork { + let key: IpAddr = ALLOWED_IPS.parse().unwrap(); + let cidr = 0u8; + ip_network::IpNetwork::new_truncate(key, cidr).unwrap() } diff --git a/common/wireguard/src/udp_listener.rs b/common/wireguard/src/udp_listener.rs index ed1440894df..ec15fbf3d1e 100644 --- a/common/wireguard/src/udp_listener.rs +++ b/common/wireguard/src/udp_listener.rs @@ -1,22 +1,31 @@ use std::{net::SocketAddr, sync::Arc}; +use dashmap::DashMap; use futures::StreamExt; use log::error; use nym_task::TaskClient; use tap::TapFallible; -use tokio::{net::UdpSocket, sync::mpsc::UnboundedSender}; +use tokio::{ + net::UdpSocket, + sync::mpsc::{self}, +}; use crate::{ event::Event, - setup::{WG_ADDRESS, WG_PORT}, - ActivePeers, + network_table::NetworkTable, + setup::{self, WG_ADDRESS, WG_PORT}, + TunTaskTx, }; const MAX_PACKET: usize = 65535; -pub async fn start_udp_listener( - tun_task_tx: UnboundedSender>, +pub(crate) type ActivePeers = DashMap>; +pub(crate) type PeersByIp = NetworkTable>; + +pub(crate) async fn start_udp_listener( + tun_task_tx: TunTaskTx, active_peers: Arc, + peers_by_ip: Arc>, mut task_client: TaskClient, ) -> Result<(), Box> { let wg_address = SocketAddr::new(WG_ADDRESS.parse().unwrap(), WG_PORT); @@ -24,7 +33,9 @@ pub async fn start_udp_listener( let udp_socket = Arc::new(UdpSocket::bind(wg_address).await?); // Setup some static keys for development - let (static_private, peer_static_public) = crate::setup::init_static_dev_keys(); + let static_private = setup::server_static_private_key(); + let peer_static_public = setup::peer_static_public_key(); + let peer_allowed_ips = setup::peer_allowed_ips(); tokio::spawn(async move { // Each tunnel is run in its own task, and the task handle is stored here so we can remove @@ -44,6 +55,7 @@ pub async fn start_udp_listener( Ok(addr) => { log::info!("Removing peer: {addr:?}"); active_peers.remove(&addr); + // TODO: remove from peers_by_ip } Err(err) => { error!("WireGuard UDP listener: error receiving shutdown from peer: {err}"); @@ -61,13 +73,22 @@ pub async fn start_udp_listener( .unwrap(); } else { log::info!("udp: received {len} bytes from {addr} from unknown peer, starting tunnel"); - let (join_handle, peer_tx) = crate::tun::start_wg_tunnel( + // TODO: this is a temporary solution for development since this + // assumes we know the peer_static_public this corresponds to. + // TODO: rework this before production! This is likely not secure! + log::warn!("Assuming peer_static_public is known"); + log::warn!("SECURITY: Rework me to do proper handshake before creating the tunnel!"); + let (join_handle, peer_tx) = crate::wg_tunnel::start_wg_tunnel( addr, udp_socket.clone(), static_private.clone(), peer_static_public, + peer_allowed_ips, tun_task_tx.clone(), ); + + peers_by_ip.lock().unwrap().insert(peer_allowed_ips, peer_tx.clone()); + peer_tx.send(Event::WgPacket(buf[..len].to_vec().into())) .tap_err(|err| log::error!("{err}")) .unwrap(); diff --git a/common/wireguard/src/tun.rs b/common/wireguard/src/wg_tunnel.rs similarity index 75% rename from common/wireguard/src/tun.rs rename to common/wireguard/src/wg_tunnel.rs index 3dbafecf322..e164d9f7f85 100644 --- a/common/wireguard/src/tun.rs +++ b/common/wireguard/src/wg_tunnel.rs @@ -6,7 +6,6 @@ use boringtun::{ x25519, }; use bytes::Bytes; -use etherparse::{InternetSlice, SlicedPacket}; use log::{debug, error, info, warn}; use tap::TapFallible; use tokio::{ @@ -15,7 +14,7 @@ use tokio::{ time::timeout, }; -use crate::{error::WgError, event::Event}; +use crate::{error::WgError, event::Event, network_table::NetworkTable, TunTaskTx}; const MAX_PACKET: usize = 65535; @@ -27,10 +26,10 @@ pub struct WireGuardTunnel { udp: Arc, // Peer endpoint - endpoint: SocketAddr, + endpoint: Arc>, - // The source address of the last packet received from the peer - source_addr: Arc>>, + // AllowedIPs for this peer + allowed_ips: NetworkTable<()>, // `boringtun` tunnel, used for crypto & WG protocol wg_tunnel: Arc>, @@ -40,7 +39,7 @@ pub struct WireGuardTunnel { close_rx: broadcast::Receiver<()>, // Send data to the task that handles sending data through the tun device - tun_task_tx: mpsc::UnboundedSender>, + tun_task_tx: TunTaskTx, } impl Drop for WireGuardTunnel { @@ -51,12 +50,13 @@ impl Drop for WireGuardTunnel { } impl WireGuardTunnel { - pub fn new( + pub(crate) fn new( udp: Arc, endpoint: SocketAddr, static_private: x25519::StaticSecret, peer_static_public: x25519::PublicKey, - tunnel_tx: mpsc::UnboundedSender>, + peer_allowed_ips: ip_network::IpNetwork, + tunnel_tx: TunTaskTx, ) -> (Self, mpsc::UnboundedSender) { let local_addr = udp.local_addr().unwrap(); let peer_addr = udp.peer_addr(); @@ -85,11 +85,14 @@ impl WireGuardTunnel { // Signal close tunnel let (close_tx, close_rx) = broadcast::channel(1); + let mut allowed_ips = NetworkTable::new(); + allowed_ips.insert(peer_allowed_ips, ()); + let tunnel = WireGuardTunnel { peer_rx, udp, - endpoint, - source_addr: Default::default(), + endpoint: Arc::new(tokio::sync::RwLock::new(endpoint)), + allowed_ips, wg_tunnel, close_tx, close_rx, @@ -134,7 +137,7 @@ impl WireGuardTunnel { }, } } - info!("WireGuard tunnel ({}): closed", self.endpoint); + info!("WireGuard tunnel ({}): closed", self.endpoint.read().await); } async fn wg_tunnel_lock(&self) -> Result, WgError> { @@ -143,16 +146,11 @@ impl WireGuardTunnel { .map_err(|_| WgError::UnableToGetTunnel) } - fn set_source_addr(&self, source_addr: std::net::Ipv4Addr) { - let to_update = { - let stored_source_addr = self.source_addr.read().unwrap(); - stored_source_addr - .map(|sa| sa != source_addr) - .unwrap_or(true) - }; - if to_update { - log::info!("wg tunnel set_source_addr: {source_addr}"); - *self.source_addr.write().unwrap() = Some(source_addr); + #[allow(unused)] + async fn set_endpoint(&self, addr: SocketAddr) { + if *self.endpoint.read().await != addr { + log::info!("wg tunnel update endpoint: {addr}"); + *self.endpoint.write().await = addr; } } @@ -161,8 +159,9 @@ impl WireGuardTunnel { let mut tunnel = self.wg_tunnel_lock().await?; match tunnel.decapsulate(None, data, &mut send_buf) { TunnResult::WriteToNetwork(packet) => { - log::info!("udp: send {} bytes to {}", packet.len(), self.endpoint); - if let Err(err) = self.udp.send_to(packet, self.endpoint).await { + let endpoint = self.endpoint.read().await; + log::info!("udp: send {} bytes to {}", packet.len(), *endpoint); + if let Err(err) = self.udp.send_to(packet, *endpoint).await { error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {err:?}"); }; // Flush pending queue @@ -170,8 +169,8 @@ impl WireGuardTunnel { let mut send_buf = [0u8; MAX_PACKET]; match tunnel.decapsulate(None, &[], &mut send_buf) { TunnResult::WriteToNetwork(packet) => { - log::info!("udp: send {} bytes to {}", packet.len(), self.endpoint); - if let Err(err) = self.udp.send_to(packet, self.endpoint).await { + log::info!("udp: send {} bytes to {}", packet.len(), *endpoint); + if let Err(err) = self.udp.send_to(packet, *endpoint).await { error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {err:?}"); break; }; @@ -182,14 +181,23 @@ impl WireGuardTunnel { } } } - TunnResult::WriteToTunnelV4(packet, _) | TunnResult::WriteToTunnelV6(packet, _) => { - let headers = SlicedPacket::from_ip(packet).unwrap(); - let (source_addr, _destination_addr) = match headers.ip.unwrap() { - InternetSlice::Ipv4(ip, _) => (ip.source_addr(), ip.destination_addr()), - InternetSlice::Ipv6(_, _) => unimplemented!(), - }; - self.set_source_addr(source_addr); - self.tun_task_tx.send(packet.to_vec()).unwrap(); + TunnResult::WriteToTunnelV4(packet, addr) => { + // TODO: once the flow is redone, we should add updating the endpoint dynamically + // self.set_endpoint(addr); + if self.allowed_ips.longest_match(addr).is_some() { + self.tun_task_tx.send(packet.to_vec()).unwrap(); + } else { + warn!("Packet from {addr} not in allowed_ips"); + } + } + TunnResult::WriteToTunnelV6(packet, addr) => { + // TODO: once the flow is redone, we should add updating the endpoint dynamically + // self.set_endpoint(addr); + if self.allowed_ips.longest_match(addr).is_some() { + self.tun_task_tx.send(packet.to_vec()).unwrap(); + } else { + warn!("Packet (v6) from {addr} not in allowed_ips"); + } } TunnResult::Done => { debug!("WireGuard: decapsulate done"); @@ -209,9 +217,10 @@ impl WireGuardTunnel { encapsulated_packet.len() ); - info!("consume_eth: send to {}: {}", self.endpoint, data.len()); + let endpoint = self.endpoint.read().await; + info!("consume_eth: send to {}: {}", *endpoint, data.len()); self.udp - .send_to(&encapsulated_packet, self.endpoint) + .send_to(&encapsulated_packet, *endpoint) .await .unwrap(); } @@ -244,12 +253,9 @@ impl WireGuardTunnel { async fn handle_routine_tun_result<'a: 'async_recursion>(&self, result: TunnResult<'a>) { match result { TunnResult::WriteToNetwork(packet) => { - log::info!( - "routine: write to network: {}: {}", - self.endpoint, - packet.len() - ); - if let Err(err) = self.udp.send_to(packet, self.endpoint).await { + let endpoint = self.endpoint.read().await; + log::info!("routine: write to network: {}: {}", endpoint, packet.len()); + if let Err(err) = self.udp.send_to(packet, *endpoint).await { error!("routine: failed to send packet: {err:?}"); }; } @@ -276,18 +282,25 @@ impl WireGuardTunnel { } } -pub fn start_wg_tunnel( +pub(crate) fn start_wg_tunnel( endpoint: SocketAddr, udp: Arc, static_private: x25519::StaticSecret, peer_static_public: x25519::PublicKey, - tunnel_tx: mpsc::UnboundedSender>, + peer_allowed_ips: ip_network::IpNetwork, + tunnel_tx: TunTaskTx, ) -> ( tokio::task::JoinHandle, mpsc::UnboundedSender, ) { - let (mut tunnel, peer_tx) = - WireGuardTunnel::new(udp, endpoint, static_private, peer_static_public, tunnel_tx); + let (mut tunnel, peer_tx) = WireGuardTunnel::new( + udp, + endpoint, + static_private, + peer_static_public, + peer_allowed_ips, + tunnel_tx, + ); let join_handle = tokio::spawn(async move { tunnel.spin_off().await; endpoint