From 8b2326ef2e1a1e56c798e0751384772bdda134d9 Mon Sep 17 00:00:00 2001 From: Hasan Date: Tue, 14 Jan 2025 15:05:25 +0100 Subject: [PATCH] Use a separate function to Q encapsulation --- neptun/src/device/mod.rs | 27 ++++++++++++-- neptun/src/noise/mod.rs | 74 +++++++++++++++++++++++++------------ neptun/src/noise/session.rs | 4 +- neptun/src/noise/timers.rs | 21 +++-------- xray/src/main.rs | 2 +- 5 files changed, 81 insertions(+), 47 deletions(-) diff --git a/neptun/src/device/mod.rs b/neptun/src/device/mod.rs index bd35936..8af9368 100644 --- a/neptun/src/device/mod.rs +++ b/neptun/src/device/mod.rs @@ -716,16 +716,21 @@ impl Device { Box::new(|d, t| { let peer_map = &d.peers; - match (d.udp4.as_ref(), d.udp6.as_ref()) { - (Some(_), Some(_)) => (), + let (udp4, udp6) = match (d.udp4.as_ref(), d.udp6.as_ref()) { + (Some(udp4), Some(udp6)) => (udp4, udp6), _ => return Action::Continue, }; // Go over each peer and invoke the timer function for peer in peer_map.values() { + let endpoint_addr = match peer.endpoint().addr { + Some(addr) => addr, + None => continue, + }; + let res = { let mut tun = peer.tunnel.lock(); - tun.update_timers(&mut t.dst_buf[..], peer.endpoint_ref()) + tun.update_timers(&mut t.dst_buf[..]) }; match res { TunnResult::Done => {} @@ -733,6 +738,20 @@ impl Device { peer.shutdown_endpoint(); // close open udp socket } TunnResult::Err(e) => tracing::error!(message = "Timer error", error = ?e), + TunnResult::WriteToNetwork(packet) => { + let res = match endpoint_addr { + SocketAddr::V4(_) => { + udp4.send_to(packet, &endpoint_addr.into()) + } + SocketAddr::V6(_) => { + udp6.send_to(packet, &endpoint_addr.into()) + } + }; + + if let Err(err) = res { + tracing::warn!(message = "Failed to send timers request", error = ?err, dst = ?endpoint_addr); + } + } _ => panic!("Unexpected result from update_timers"), }; } @@ -1081,7 +1100,7 @@ impl Device { let res = { let mut tun = peer.tunnel.lock(); - tun.encapsulate(len, element, iter, peer.endpoint_ref()) + tun.queue_encapsulate(len, element, iter, peer.endpoint_ref()) }; } } diff --git a/neptun/src/noise/mod.rs b/neptun/src/noise/mod.rs index 67c7864..914fee0 100644 --- a/neptun/src/noise/mod.rs +++ b/neptun/src/noise/mod.rs @@ -15,7 +15,7 @@ mod timers; use crossbeam::channel::{Receiver, Sender}; use ring_buffers::{EncryptionTaskData, TX_RING_BUFFER}; -use session::Session; +use session::{Session, AEAD_SIZE, DATA_OFFSET}; use crate::noise::errors::WireGuardError; use crate::noise::handshake::Handshake; @@ -306,13 +306,46 @@ impl Tunn { } } + pub fn encapsulate<'a>(&mut self, src: &[u8], dst: &'a mut [u8]) -> TunnResult<'a> { + let current = self.current; + if let Some(ref session) = self.sessions[current % N_SESSIONS] { + // Send the packet using an established session + let (packet, _) = Session::format_packet_data( + session.get_sending_key_counter(), + session.get_sending_index(), + session.get_sender_key(), + src.len(), + dst, + ); + + // Send the notification on the channel to encrypt the packet + self.mark_timer_to_update(TimerName::TimeLastPacketSent); + // Exclude Keepalive packets from timer update. + if !src.is_empty() { + self.mark_timer_to_update(TimerName::TimeLastDataPacketSent); + } + self.tx_bytes += packet.len(); + return TunnResult::WriteToNetwork(packet); + } + + if !src.is_empty() { + // If there is no session, queue the packet for future retry, + // except if it's keepalive packet, new keepalive packets will be sent when session is created. + // This prevents double keepalive packets on initiation + self.queue_packet(src); + } + + // Initiate a new handshake if none is in progress + self.format_handshake_initiation(dst, false) + } + /// Encapsulate a single packet from the tunnel interface. /// Returns TunnResult. /// /// # Panics /// Panics if dst buffer is too small. /// Size of dst should be at least src.len() + 32, and no less than 148 bytes. - pub fn encapsulate<'a>( + pub fn queue_encapsulate<'a>( &mut self, len: usize, element: &'static mut EncryptionTaskData, @@ -335,9 +368,7 @@ impl Tunn { if len != 0 { self.mark_timer_to_update(TimerName::TimeLastDataPacketSent); } - // TODO! - Can't set the len here, as 1) wrong length. 2) don't know - // whether error would come or not. - // self.tx_bytes += packet.len(); + self.tx_bytes += len + DATA_OFFSET + AEAD_SIZE; // TODO! - This has to change let _ = self.encrypt_tx_chan.send(iter); @@ -367,12 +398,14 @@ impl Tunn { dst.endpoint = endpoint; let res = self.format_handshake_initiation(dst.data.as_mut_slice(), force_resend); match res { - NeptunResult::Done => return, - NeptunResult::Err(e) => { + TunnResult::Done => return, + TunnResult::Err(e) => { tracing::error!(message = "Handshake initiation error", error = ?e); return; } - NeptunResult::WriteToNetwork(n) => dst.res = NeptunResult::WriteToNetwork(n), + TunnResult::WriteToNetwork(buf) => { + dst.res = NeptunResult::WriteToNetwork(buf.len()) + } _ => panic!("Unexpected result from handshake initiation"), } }; @@ -572,9 +605,9 @@ impl Tunn { &mut self, dst: &'a mut [u8], force_resend: bool, - ) -> NeptunResult { + ) -> TunnResult<'a> { if self.handshake.is_in_progress() && !force_resend { - return NeptunResult::Done; + return TunnResult::Done; } if self.handshake.is_expired() { @@ -593,9 +626,9 @@ impl Tunn { self.mark_timer_to_update(TimerName::TimeLastPacketSent); self.tx_bytes += packet.len(); - NeptunResult::WriteToNetwork(packet.len()) + TunnResult::WriteToNetwork(packet) } - Err(e) => NeptunResult::Err(e), + Err(e) => TunnResult::Err(e), } } @@ -738,7 +771,6 @@ impl Tunn { #[cfg(test)] mod tests { - use std::sync::atomic::AtomicBool; #[cfg(feature = "mock-instant")] use crate::noise::timers::{REKEY_AFTER_TIME, REKEY_TIMEOUT}; @@ -787,14 +819,14 @@ mod tests { fn create_handshake_init(tun: &mut Tunn) -> Vec { let mut dst = vec![0u8; 2048]; let handshake_init = tun.format_handshake_initiation(&mut dst, false); - assert!(matches!(handshake_init, NeptunResult::WriteToNetwork(_))); - let handshake_init = if let NeptunResult::WriteToNetwork(sent) = handshake_init { + assert!(matches!(handshake_init, TunnResult::WriteToNetwork(_))); + let handshake_init = if let TunnResult::WriteToNetwork(sent) = handshake_init { sent } else { unreachable!(); }; - dst[..handshake_init].to_vec() + handshake_init.into() } fn create_handshake_response(tun: &mut Tunn, handshake_init: &[u8]) -> Vec { @@ -900,14 +932,8 @@ mod tests { fn full_handshake_plus_timers() { let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake(); // Time has not yet advanced so their is nothing to do - assert!(matches!( - my_tun.update_timers(&mut [], Arc::default()), - TunnResult::Done - )); - assert!(matches!( - their_tun.update_timers(&mut [], Arc::default()), - TunnResult::Done - )); + assert!(matches!(my_tun.update_timers(&mut []), TunnResult::Done)); + assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done)); } #[test] diff --git a/neptun/src/noise/session.rs b/neptun/src/noise/session.rs index 4db3797..0f28dee 100644 --- a/neptun/src/noise/session.rs +++ b/neptun/src/noise/session.rs @@ -35,9 +35,9 @@ impl std::fmt::Debug for Session { } /// Where encrypted data resides in a data packet -const DATA_OFFSET: usize = 16; +pub const DATA_OFFSET: usize = 16; /// The overhead of the AEAD -const AEAD_SIZE: usize = 16; +pub const AEAD_SIZE: usize = 16; // Receiving buffer constants const WORD_SIZE: u64 = 64; diff --git a/neptun/src/noise/timers.rs b/neptun/src/noise/timers.rs index b839e85..85e62d1 100644 --- a/neptun/src/noise/timers.rs +++ b/neptun/src/noise/timers.rs @@ -3,17 +3,14 @@ // SPDX-License-Identifier: BSD-3-Clause use super::errors::WireGuardError; -use super::ring_buffers::TX_RING_BUFFER; -use crate::noise::{safe_duration::SafeDuration as Duration, Endpoint, Tunn, TunnResult}; +use crate::noise::{safe_duration::SafeDuration as Duration, Tunn, TunnResult}; use std::mem; use std::ops::{Index, IndexMut}; -use std::sync::atomic::{AtomicU16, Ordering}; -use std::sync::Arc; +use std::sync::atomic::AtomicU16; use std::time::SystemTime; #[cfg(feature = "mock-instant")] use mock_instant::Instant; -use parking_lot::RwLock; #[cfg(not(any( feature = "mock-instant", @@ -225,11 +222,7 @@ impl Tunn { } } - pub fn update_timers<'a>( - &mut self, - dst: &'a mut [u8], - endpoint: Arc>, - ) -> TunnResult<'a> { + pub fn update_timers<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> { let mut handshake_initiation_required = false; let mut keepalive_required = false; @@ -377,15 +370,11 @@ impl Tunn { } if handshake_initiation_required { - self.initiate_handshake(endpoint, true); - return TunnResult::Done; + return self.format_handshake_initiation(dst, true); } if keepalive_required { - let (element, iter) = unsafe { TX_RING_BUFFER.get_next() }; - if element.is_element_free.load(Ordering::Relaxed) { - self.encapsulate(0, element, iter, endpoint); - } + return self.encapsulate(&[], dst); } TunnResult::Done diff --git a/xray/src/main.rs b/xray/src/main.rs index cc763dc..e1ab1ae 100644 --- a/xray/src/main.rs +++ b/xray/src/main.rs @@ -119,7 +119,7 @@ async fn main() -> EyreResult<()> { let plaintext_client = Client::new(PLAINTEXT_ADDR, None, plaintext_sock); let crypto_sock = UdpSocket::bind(CRYPTO_SOCK_ADDR).await?; - let tunn = Tunn::new(peer_keys.private, wg_keys.public, None, None, 123, None) + let tunn = Tunn::new(peer_keys.private, wg_keys.public, None, None, 123, None, None, None, None) .map_err(|s| XRayError::UnexpectedTunnResult(s.to_owned()))?; let mut crypto_client = Client::new(CRYPTO_ADDR, Some(tunn), crypto_sock); crypto_client.do_handshake(WG_ADDR).await?;