Skip to content

Commit

Permalink
Use a separate function to Q encapsulation
Browse files Browse the repository at this point in the history
  • Loading branch information
Hasan6979 committed Jan 14, 2025
1 parent 63f60e6 commit 8b2326e
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 47 deletions.
27 changes: 23 additions & 4 deletions neptun/src/device/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -716,23 +716,42 @@ impl Device {
Box::new(|d, t| {
let peer_map = &d.peers;

match (d.udp4.as_ref(), d.udp6.as_ref()) {
(Some(_), Some(_)) => (),
let (udp4, udp6) = match (d.udp4.as_ref(), d.udp6.as_ref()) {
(Some(udp4), Some(udp6)) => (udp4, udp6),
_ => return Action::Continue,
};

// Go over each peer and invoke the timer function
for peer in peer_map.values() {
let endpoint_addr = match peer.endpoint().addr {
Some(addr) => addr,
None => continue,
};

let res = {
let mut tun = peer.tunnel.lock();
tun.update_timers(&mut t.dst_buf[..], peer.endpoint_ref())
tun.update_timers(&mut t.dst_buf[..])
};
match res {
TunnResult::Done => {}
TunnResult::Err(WireGuardError::ConnectionExpired) => {
peer.shutdown_endpoint(); // close open udp socket
}
TunnResult::Err(e) => tracing::error!(message = "Timer error", error = ?e),
TunnResult::WriteToNetwork(packet) => {
let res = match endpoint_addr {
SocketAddr::V4(_) => {
udp4.send_to(packet, &endpoint_addr.into())
}
SocketAddr::V6(_) => {
udp6.send_to(packet, &endpoint_addr.into())
}
};

if let Err(err) = res {
tracing::warn!(message = "Failed to send timers request", error = ?err, dst = ?endpoint_addr);
}
}
_ => panic!("Unexpected result from update_timers"),
};
}
Expand Down Expand Up @@ -1081,7 +1100,7 @@ impl Device {

let res = {
let mut tun = peer.tunnel.lock();
tun.encapsulate(len, element, iter, peer.endpoint_ref())
tun.queue_encapsulate(len, element, iter, peer.endpoint_ref())
};
}
}
Expand Down
74 changes: 50 additions & 24 deletions neptun/src/noise/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ mod timers;

use crossbeam::channel::{Receiver, Sender};
use ring_buffers::{EncryptionTaskData, TX_RING_BUFFER};
use session::Session;
use session::{Session, AEAD_SIZE, DATA_OFFSET};

use crate::noise::errors::WireGuardError;
use crate::noise::handshake::Handshake;
Expand Down Expand Up @@ -306,13 +306,46 @@ impl Tunn {
}
}

pub fn encapsulate<'a>(&mut self, src: &[u8], dst: &'a mut [u8]) -> TunnResult<'a> {
let current = self.current;
if let Some(ref session) = self.sessions[current % N_SESSIONS] {
// Send the packet using an established session
let (packet, _) = Session::format_packet_data(
session.get_sending_key_counter(),
session.get_sending_index(),
session.get_sender_key(),
src.len(),
dst,
);

// Send the notification on the channel to encrypt the packet
self.mark_timer_to_update(TimerName::TimeLastPacketSent);
// Exclude Keepalive packets from timer update.
if !src.is_empty() {
self.mark_timer_to_update(TimerName::TimeLastDataPacketSent);
}
self.tx_bytes += packet.len();
return TunnResult::WriteToNetwork(packet);
}

if !src.is_empty() {
// If there is no session, queue the packet for future retry,
// except if it's keepalive packet, new keepalive packets will be sent when session is created.
// This prevents double keepalive packets on initiation
self.queue_packet(src);
}

// Initiate a new handshake if none is in progress
self.format_handshake_initiation(dst, false)
}

/// Encapsulate a single packet from the tunnel interface.
/// Returns TunnResult.
///
/// # Panics
/// Panics if dst buffer is too small.
/// Size of dst should be at least src.len() + 32, and no less than 148 bytes.
pub fn encapsulate<'a>(
pub fn queue_encapsulate<'a>(
&mut self,
len: usize,
element: &'static mut EncryptionTaskData,
Expand All @@ -335,9 +368,7 @@ impl Tunn {
if len != 0 {
self.mark_timer_to_update(TimerName::TimeLastDataPacketSent);
}
// TODO! - Can't set the len here, as 1) wrong length. 2) don't know
// whether error would come or not.
// self.tx_bytes += packet.len();
self.tx_bytes += len + DATA_OFFSET + AEAD_SIZE;

// TODO! - This has to change
let _ = self.encrypt_tx_chan.send(iter);
Expand Down Expand Up @@ -367,12 +398,14 @@ impl Tunn {
dst.endpoint = endpoint;
let res = self.format_handshake_initiation(dst.data.as_mut_slice(), force_resend);
match res {
NeptunResult::Done => return,
NeptunResult::Err(e) => {
TunnResult::Done => return,
TunnResult::Err(e) => {
tracing::error!(message = "Handshake initiation error", error = ?e);
return;
}
NeptunResult::WriteToNetwork(n) => dst.res = NeptunResult::WriteToNetwork(n),
TunnResult::WriteToNetwork(buf) => {
dst.res = NeptunResult::WriteToNetwork(buf.len())
}
_ => panic!("Unexpected result from handshake initiation"),
}
};
Expand Down Expand Up @@ -572,9 +605,9 @@ impl Tunn {
&mut self,
dst: &'a mut [u8],
force_resend: bool,
) -> NeptunResult {
) -> TunnResult<'a> {
if self.handshake.is_in_progress() && !force_resend {
return NeptunResult::Done;
return TunnResult::Done;
}

if self.handshake.is_expired() {
Expand All @@ -593,9 +626,9 @@ impl Tunn {
self.mark_timer_to_update(TimerName::TimeLastPacketSent);
self.tx_bytes += packet.len();

NeptunResult::WriteToNetwork(packet.len())
TunnResult::WriteToNetwork(packet)
}
Err(e) => NeptunResult::Err(e),
Err(e) => TunnResult::Err(e),
}
}

Expand Down Expand Up @@ -738,7 +771,6 @@ impl Tunn {

#[cfg(test)]
mod tests {
use std::sync::atomic::AtomicBool;

#[cfg(feature = "mock-instant")]
use crate::noise::timers::{REKEY_AFTER_TIME, REKEY_TIMEOUT};
Expand Down Expand Up @@ -787,14 +819,14 @@ mod tests {
fn create_handshake_init(tun: &mut Tunn) -> Vec<u8> {
let mut dst = vec![0u8; 2048];
let handshake_init = tun.format_handshake_initiation(&mut dst, false);
assert!(matches!(handshake_init, NeptunResult::WriteToNetwork(_)));
let handshake_init = if let NeptunResult::WriteToNetwork(sent) = handshake_init {
assert!(matches!(handshake_init, TunnResult::WriteToNetwork(_)));
let handshake_init = if let TunnResult::WriteToNetwork(sent) = handshake_init {
sent
} else {
unreachable!();
};

dst[..handshake_init].to_vec()
handshake_init.into()
}

fn create_handshake_response(tun: &mut Tunn, handshake_init: &[u8]) -> Vec<u8> {
Expand Down Expand Up @@ -900,14 +932,8 @@ mod tests {
fn full_handshake_plus_timers() {
let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake();
// Time has not yet advanced so their is nothing to do
assert!(matches!(
my_tun.update_timers(&mut [], Arc::default()),
TunnResult::Done
));
assert!(matches!(
their_tun.update_timers(&mut [], Arc::default()),
TunnResult::Done
));
assert!(matches!(my_tun.update_timers(&mut []), TunnResult::Done));
assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done));
}

#[test]
Expand Down
4 changes: 2 additions & 2 deletions neptun/src/noise/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ impl std::fmt::Debug for Session {
}

/// Where encrypted data resides in a data packet
const DATA_OFFSET: usize = 16;
pub const DATA_OFFSET: usize = 16;
/// The overhead of the AEAD
const AEAD_SIZE: usize = 16;
pub const AEAD_SIZE: usize = 16;

// Receiving buffer constants
const WORD_SIZE: u64 = 64;
Expand Down
21 changes: 5 additions & 16 deletions neptun/src/noise/timers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,14 @@
// SPDX-License-Identifier: BSD-3-Clause

use super::errors::WireGuardError;
use super::ring_buffers::TX_RING_BUFFER;
use crate::noise::{safe_duration::SafeDuration as Duration, Endpoint, Tunn, TunnResult};
use crate::noise::{safe_duration::SafeDuration as Duration, Tunn, TunnResult};
use std::mem;
use std::ops::{Index, IndexMut};
use std::sync::atomic::{AtomicU16, Ordering};
use std::sync::Arc;
use std::sync::atomic::AtomicU16;
use std::time::SystemTime;

#[cfg(feature = "mock-instant")]
use mock_instant::Instant;
use parking_lot::RwLock;

#[cfg(not(any(
feature = "mock-instant",
Expand Down Expand Up @@ -225,11 +222,7 @@ impl Tunn {
}
}

pub fn update_timers<'a>(
&mut self,
dst: &'a mut [u8],
endpoint: Arc<RwLock<Endpoint>>,
) -> TunnResult<'a> {
pub fn update_timers<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> {
let mut handshake_initiation_required = false;
let mut keepalive_required = false;

Expand Down Expand Up @@ -377,15 +370,11 @@ impl Tunn {
}

if handshake_initiation_required {
self.initiate_handshake(endpoint, true);
return TunnResult::Done;
return self.format_handshake_initiation(dst, true);
}

if keepalive_required {
let (element, iter) = unsafe { TX_RING_BUFFER.get_next() };
if element.is_element_free.load(Ordering::Relaxed) {
self.encapsulate(0, element, iter, endpoint);
}
return self.encapsulate(&[], dst);
}

TunnResult::Done
Expand Down
2 changes: 1 addition & 1 deletion xray/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ async fn main() -> EyreResult<()> {
let plaintext_client = Client::new(PLAINTEXT_ADDR, None, plaintext_sock);

let crypto_sock = UdpSocket::bind(CRYPTO_SOCK_ADDR).await?;
let tunn = Tunn::new(peer_keys.private, wg_keys.public, None, None, 123, None)
let tunn = Tunn::new(peer_keys.private, wg_keys.public, None, None, 123, None, None, None, None)
.map_err(|s| XRayError::UnexpectedTunnResult(s.to_owned()))?;
let mut crypto_client = Client::new(CRYPTO_ADDR, Some(tunn), crypto_sock);
crypto_client.do_handshake(WG_ADDR).await?;
Expand Down

0 comments on commit 8b2326e

Please sign in to comment.