diff --git a/nflux-ebpf/src/main.rs b/nflux-ebpf/src/main.rs index 23bfb6d..fad1d8c 100644 --- a/nflux-ebpf/src/main.rs +++ b/nflux-ebpf/src/main.rs @@ -3,7 +3,7 @@ #![allow(nonstandard_style, dead_code)] use aya_ebpf::maps::lpm_trie::Key; -use aya_ebpf::maps::{Array, LpmTrie}; +use aya_ebpf::maps::{Array, LpmTrie, LruHashMap}; use aya_ebpf::{ bindings::xdp_action, macros::{map, xdp}, @@ -38,6 +38,9 @@ static ICMP_RULE: Array = Array::with_max_entries(1, 0); #[map] static EGRESS_EVENT: PerfEventArray = PerfEventArray::new(0); +#[map] +static ACTIVE_CONNECTIONS: LruHashMap = LruHashMap::with_max_entries(1024, 0); + #[xdp] pub fn nflux(ctx: XdpContext) -> u32 { match start_nflux(ctx) { @@ -78,20 +81,27 @@ fn log_new_connection(ctx: XdpContext, src_addr: u32, dst_port: u16, protocol: u fn try_tc_egress(ctx: TcContext) -> Result { let ethhdr: EthHdr = ctx.load(0).map_err(|_| ())?; match ethhdr.ether_type { - EtherType::Ipv4 => {} + EtherType::Ipv4 => unsafe { + let ipv4hdr: Ipv4Hdr = ctx.load(EthHdr::LEN).map_err(|_| ())?; + let destination = u32::from_be(ipv4hdr.dst_addr); + + // Check if this destination is already active + if ACTIVE_CONNECTIONS.get(&destination).is_none() { + // Log only new connections + let event = EgressEvent { dst_ip: destination }; + EGRESS_EVENT.output(&ctx, &event, 0); + + // Mark connection as active + ACTIVE_CONNECTIONS.insert(&destination, &1, 0).map_err(|_| ())?; + } + } _ => return Ok(TC_ACT_PIPE), } - let ipv4hdr: Ipv4Hdr = ctx.load(EthHdr::LEN).map_err(|_| ())?; - let destination = u32::from_be(ipv4hdr.dst_addr); - - let event = EgressEvent { dst_ip: destination }; - - EGRESS_EVENT.output(&ctx, &event, 0); - Ok(TC_ACT_PIPE) } + fn start_nflux(ctx: XdpContext) -> Result { let ethhdr: *const EthHdr = unsafe { ptr_at(&ctx, 0)? }; diff --git a/nflux/src/main.rs b/nflux/src/main.rs index 3368bc3..d879a88 100644 --- a/nflux/src/main.rs +++ b/nflux/src/main.rs @@ -15,7 +15,7 @@ use bytes::BytesMut; use config::{Action, Nflux, Protocol, IpRules}; use core::set_mem_limit; use logger::setup_logger; -use nflux_common::{convert_protocol, ConnectionEvent, IpRule, LpmKeyIpv4, LpmKeyIpv6}; +use nflux_common::{convert_protocol, ConnectionEvent, EgressEvent, IpRule, LpmKeyIpv4, LpmKeyIpv6}; use std::collections::HashMap; use std::net::{Ipv4Addr, Ipv6Addr}; use std::ptr; @@ -82,22 +82,58 @@ async fn main() -> anyhow::Result<()> { )?; let mut egress_events = AsyncPerfEventArray::try_from( - bpf.take_map("EGRESS_EVENTS") - .context("Failed to find EGRESS_EVENTS map")?, + bpf.take_map("EGRESS_EVENT") + .context("Failed to find EGRESS_EVENT map")?, )?; let cpus = online_cpus().map_err(|(_, error)| error)?; for cpu_id in cpus { - let buf = events.open(cpu_id, None)?; - task::spawn(process_events(buf, cpu_id)); + // Spawn task for connection events + { + let buf = events.open(cpu_id, None)?; + task::spawn(process_events(buf, cpu_id)); + } + + // Spawn task for egress events + { + let buf = egress_events.open(cpu_id, None)?; + task::spawn(process_egress_events(buf, cpu_id)); + } } + // Wait for shutdown signal wait_for_shutdown().await?; Ok(()) } +async fn process_egress_events( + mut buf: AsyncPerfEventArrayBuffer, + cpu_id: u32, +) -> Result<(), PerfBufferError> { + let mut buffers = vec![BytesMut::with_capacity(1024); 10]; + + loop { + // Wait for events + let events = buf.read_events(&mut buffers).await?; + + // Process each event in the buffer + for i in 0..events.read { + let buf = &buffers[i]; + match parse_egress_event(buf) { + Ok(event) => { + info!( + "direction=outgoing ip={}", + Ipv4Addr::from(event.dst_ip) + ); + } + Err(e) => error!("Failed to parse egress event on CPU {}: {}", cpu_id, e), + } + } + } +} + async fn process_events( mut buf: AsyncPerfEventArrayBuffer, cpu_id: u32, @@ -127,6 +163,18 @@ async fn process_events( } } +fn parse_egress_event(buf: &BytesMut) -> anyhow::Result { + if buf.len() >= std::mem::size_of::() { + let ptr = buf.as_ptr() as *const EgressEvent; + let event = unsafe { ptr::read_unaligned(ptr) }; + Ok(event) + } else { + Err(anyhow::anyhow!( + "Buffer size is too small for EgressEvent" + )) + } +} + fn parse_connection_event(buf: &BytesMut) -> anyhow::Result { if buf.len() >= std::mem::size_of::() { let ptr = buf.as_ptr() as *const ConnectionEvent;