From 89e66bcd107206a7a6a876640a9f7355e989fc4a Mon Sep 17 00:00:00 2001 From: containerscrew Date: Tue, 3 Dec 2024 19:17:59 +0100 Subject: [PATCH] Working with rule implementation --- Cargo.lock | 8 +- nflux-common/src/lib.rs | 38 ++++---- nflux-ebpf/src/main.rs | 207 +++++++++++----------------------------- nflux.toml | 50 ++++++---- nflux/src/config.rs | 80 ---------------- nflux/src/lib.rs | 5 +- nflux/src/main.rs | 207 ++++++++++++++++++++++------------------ 7 files changed, 226 insertions(+), 369 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4400fd7..8d9997d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,9 +19,9 @@ checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" [[package]] name = "allocator-api2" -version = "0.2.20" +version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" [[package]] name = "anstream" @@ -1154,9 +1154,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.41.1" +version = "1.42.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33" +checksum = "5cec9b21b0450273377fc97bd4c33a8acffc8c996c987a7c5b319a0083707551" dependencies = [ "backtrace", "libc", diff --git a/nflux-common/src/lib.rs b/nflux-common/src/lib.rs index 7af7ebc..5bf71d1 100644 --- a/nflux-common/src/lib.rs +++ b/nflux-common/src/lib.rs @@ -13,34 +13,36 @@ pub struct ConnectionEvent { pub protocol: u8, // 6 for TCP, 17 for UDP } -// #[repr(C)] -// #[derive(Clone, Copy, Debug, PartialEq, Eq)] -// pub struct GlobalFirewallRules { -// pub icmp_enabled: u8, -// pub allowed_ipv4: [u32; MAX_ALLOWED_IPV4], -// pub allowed_ports: [u32; MAX_ALLOWED_PORTS], -// } - -// #[cfg(feature = "user")] -// pub mod user { -// use super::*; - -// unsafe impl aya::Pod for GlobalFirewallRules {} -// } - #[repr(C)] -#[derive(Clone, Copy, Debug)] -pub struct Ipv4Rule { +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct IpRule { pub action: u8, // 0 = deny, 1 = allow pub ports: [u16; 16], // Up to 16 ports pub protocol: u8, // 6 = TCP, 17 = UDP + pub priority: u32, // Lower number means higher priority +} + +#[repr(C)] +#[derive(Clone, Copy, Debug)] +pub struct LpmKeyIpv4 { + pub prefix_len: u32, + pub ip: u32, +} + +#[repr(C)] +#[derive(Clone, Copy, Debug)] +pub struct LpmKeyIpv6 { + pub prefix_len: u32, + pub ip: [u8; 16], } #[cfg(feature = "user")] pub mod user { use super::*; - unsafe impl aya::Pod for Ipv4Rule {} + unsafe impl aya::Pod for IpRule {} + unsafe impl aya::Pod for LpmKeyIpv4 {} + unsafe impl aya::Pod for LpmKeyIpv6 {} } // Define the default configuration if the user does not provide one diff --git a/nflux-ebpf/src/main.rs b/nflux-ebpf/src/main.rs index fb59936..8f2552f 100644 --- a/nflux-ebpf/src/main.rs +++ b/nflux-ebpf/src/main.rs @@ -2,14 +2,14 @@ #![no_main] #![allow(nonstandard_style, dead_code)] -use aya_ebpf::helpers::bpf_ktime_get_ns; +use aya_ebpf::maps::lpm_trie::Key; +use aya_ebpf::maps::LpmTrie; use aya_ebpf::{ bindings::xdp_action, macros::{map, xdp}, - maps::{LruHashMap, PerfEventArray}, + maps::PerfEventArray, programs::XdpContext, }; - use core::mem; use network_types::{ eth::{EthHdr, EtherType}, @@ -17,8 +17,7 @@ use network_types::{ tcp::TcpHdr, udp::UdpHdr, }; - -use nflux_common::{ConnectionEvent, Ipv4Rule}; +use nflux_common::{ConnectionEvent, IpRule, LpmKeyIpv4, LpmKeyIpv6}; #[cfg(not(test))] #[panic_handler] @@ -26,11 +25,11 @@ fn panic(_info: &core::panic::PanicInfo) -> ! { loop {} } -// #[map] -// static GLOBAL_FIREWALL_RULES: Array = Array::with_max_entries(1, 0); +#[map] +static IPV4_RULES: LpmTrie = LpmTrie::with_max_entries(1024, 0); #[map] -static IPV4_RULES: LruHashMap = LruHashMap::with_max_entries(1024, 0); +static IPV6_RULES: LpmTrie = LpmTrie::with_max_entries(1024, 0); #[map] static CONNECTION_EVENTS: PerfEventArray = PerfEventArray::new(0); @@ -56,167 +55,73 @@ unsafe fn ptr_at(ctx: &XdpContext, offset: usize) -> Result<*const T, ()> { Ok((start + offset) as *const T) } -// Check if a port is allowed -fn is_port_allowed(global_firewall_rules: &GlobalFirewallRules, port: u16) -> bool { - for &allowed_port in &global_firewall_rules.allowed_ports { - if allowed_port == 0 { - // Stop if we hit an uninitialized entry (assuming 0 indicates unused entries) - break; - } - if port as u32 == allowed_port { - return true; - } - } - false -} - -// Check if an IP address is allowed -fn is_ipv4_allowed(app_config: &GlobalFirewallRules, ip: u32) -> bool { - for &allowed_ip in &app_config.allowed_ipv4 { - if allowed_ip == 0 { - // Stop if we hit an uninitialized entry (assuming 0 indicates unused entries) - break; - } - if ip == allowed_ip { - return true; - } - } - false -} - -// Helper function to get the current time in nanoseconds -fn current_time_ns() -> u64 { - unsafe { bpf_ktime_get_ns() } -} - -#[repr(C)] -struct IpPort { - ip: u32, - port: u16, -} - -#[map] -static RECENT_LOGS: LruHashMap = LruHashMap::with_max_entries(1024, 0); - -// Function to check if we should log a dropped SYN packet to avoid excessive logging -fn should_log(ip: u32, port: u16, log_interval_secs: u64) -> bool { - let key = IpPort { ip, port }; - let now = current_time_ns(); - - unsafe { - if let Some(&last_logged) = RECENT_LOGS.get(&key) { - // Only log if more than 5 seconds have passed - if now - last_logged < log_interval_secs * 1_000_000_000 { - return false; - } - } - } - - // Update the map with the new timestamp - RECENT_LOGS.insert(&key, &now, 0).ok(); - true -} - -fn get_global_firewall_rules() -> &'static GlobalFirewallRules { - GLOBAL_FIREWALL_RULES.get(0).unwrap() -} - -fn log_new_connection( - ctx: XdpContext, - src_addr: u32, - dst_port: u16, - protocol: u8, - log_interval_secs: u64, -) { +fn log_new_connection(ctx: XdpContext, src_addr: u32, dst_port: u16, protocol: u8) { let event = ConnectionEvent { src_addr, dst_port, protocol, }; - if should_log(src_addr, dst_port, log_interval_secs) { - CONNECTION_EVENTS.output(&ctx, &event, 0); - } + CONNECTION_EVENTS.output(&ctx, &event, 0); } fn start_nflux(ctx: XdpContext) -> Result { let ethhdr: *const EthHdr = unsafe { ptr_at(&ctx, 0)? }; - // Get global firewall rules - let global_firewall_rules = get_global_firewall_rules(); - match unsafe { (*ethhdr).ether_type } { EtherType::Ipv4 => { let ipv4hdr: *const Ipv4Hdr = unsafe { ptr_at(&ctx, EthHdr::LEN)? }; - let source = u32::from_be(unsafe { (*ipv4hdr).src_addr }); + let source_ip = u32::from_be(unsafe { (*ipv4hdr).src_addr }); let proto = unsafe { (*ipv4hdr).proto }; - match proto { - IpProto::Tcp => { - // Parse the TCP header - let tcphdr: *const TcpHdr = - unsafe { ptr_at(&ctx, EthHdr::LEN + Ipv4Hdr::LEN)? }; - let dst_port = u16::from_be(unsafe { (*tcphdr).dest }); - - // if is_port_allowed(&global_firewall_rules, dst_port) { - // log_new_connection(ctx, source, dst_port, 6, 5); - // return Ok(xdp_action::XDP_PASS); - // } - - // // Check if the IP address is allowed - // if is_ipv4_allowed(&global_firewall_rules, source) { - // log_new_connection(ctx, source, dst_port, 6, 5); - // return Ok(xdp_action::XDP_PASS); - // } - - // // Deny incoming connections, except SYN-ACK packets - // if unsafe { (*tcphdr).syn() == 1 && (*tcphdr).ack() == 0 } { - // // Block unsolicited incoming SYN packets (deny incoming connections) - // return Ok(xdp_action::XDP_DROP); - // } else if unsafe { (*tcphdr).ack() == 1 } { - // // Permit ACK packets (responses to outgoing connections) - // log_new_connection(ctx, source, dst_port, 6, 5); - // return Ok(xdp_action::XDP_PASS); - // } - - Ok(xdp_action::XDP_DROP) - } - IpProto::Udp => { - // Parse UDP header - let udphdr: *const UdpHdr = - unsafe { ptr_at(&ctx, EthHdr::LEN + Ipv4Hdr::LEN)? }; - let dst_port = u16::from_be(unsafe { (*udphdr).dest }); - let src_port = u16::from_be(unsafe { (*udphdr).source }); - - // If the source port (DNS) is 53, allow the packet. Internet connection - // if src_port == 53 { - // return Ok(xdp_action::XDP_PASS); - // } - - // // Check if the IP address is blocked - // if is_ipv4_allowed(&global_firewall_rules, source) { - // log_new_connection(ctx, source, dst_port, 6, 5); - // return Ok(xdp_action::XDP_PASS); - // } - - // // Check allowed ports - // if is_port_allowed(&global_firewall_rules, dst_port) { - // log_new_connection(ctx, source, dst_port, 6, 5); - // return Ok(xdp_action::XDP_PASS); - // } - - Ok(xdp_action::XDP_DROP) + let key = Key::new( + 32, + LpmKeyIpv4 { + prefix_len: 32, + ip: source_ip, + }, + ); + + if let Some(rule) = IPV4_RULES.get(&key) { + match proto { + IpProto::Tcp => { + let tcphdr: *const TcpHdr = + unsafe { ptr_at(&ctx, EthHdr::LEN + Ipv4Hdr::LEN)? }; + let dst_port = u16::from_be(unsafe { (*tcphdr).dest }); + + if rule.ports.contains(&dst_port) { + if rule.action == 1 { + log_new_connection(ctx, source_ip, dst_port, 6); + return Ok(xdp_action::XDP_PASS); + } + } + return Ok(xdp_action::XDP_DROP); + } + IpProto::Udp => { + let udphdr: *const UdpHdr = + unsafe { ptr_at(&ctx, EthHdr::LEN + Ipv4Hdr::LEN)? }; + let dst_port = u16::from_be(unsafe { (*udphdr).dest }); + + if rule.ports.contains(&dst_port) { + if rule.action == 1 { + log_new_connection(ctx, source_ip, dst_port, 17); + return Ok(xdp_action::XDP_PASS); + } + } + return Ok(xdp_action::XDP_DROP); + } + IpProto::Icmp => { + if rule.action == 1 { + log_new_connection(ctx, source_ip, 0, 1); + return Ok(xdp_action::XDP_PASS); + } + return Ok(xdp_action::XDP_DROP); + } + _ => return Ok(xdp_action::XDP_DROP), } - // IpProto::Icmp => { - // if global_firewall_rules.allow_icmp == 1 { - // log_new_connection(ctx, source, 0, 1, 5); - // Ok(xdp_action::XDP_PASS) - // } else { - // Ok(xdp_action::XDP_DROP) - // } - // } - _ => Ok(xdp_action::XDP_DROP), } + + Ok(xdp_action::XDP_DROP) } _ => Ok(xdp_action::XDP_DROP), } diff --git a/nflux.toml b/nflux.toml index daea241..0d65833 100644 --- a/nflux.toml +++ b/nflux.toml @@ -1,23 +1,35 @@ [firewall] -# Applies for ipv4 and ipv6 -# Global means, if you put here an ip will be able to access every port, tcp and udp -# enabled = true/false #TODO: Implement this -icmp_enabled = true -interface_name = "wlp2s0" -log_level = "info" # trace, debug, info, warn or error. Defaults to info if not set -log_type = "text" # text or json. Defaults to text if not set +# TODO: Add support for multiple interfaces +interface_names = ["wlp2s0", "eth0"] +log_level = "info" # trace, debug, info, warn, or error. Defaults to info if not set +log_type = "text" # text or json. Defaults to text if not set +# TODO +# default_action = "deny" # global default action if no specific rule matches -[firewall.ipv4_rules] -# This is more finetuned, you can specify which ip can access which port -"192.168.0.4" = { action = "deny", ports = [80], protocol = "tcp" } -"192.168.0.11" = { action = "allow", ports = [53], protocol = "udp"} -"192.168.0.50" = { action = "deny", ports = [22, 443], protocol = "tcp" } -"192.168.0.100" = { action = "allow", ports = [80, 8080], protocol = "tcp"} +[ip_rules] +# Fine-tuned rules for IP-based filtering +"192.168.0.0/24" = { priority = 1, action = "deny", ports = [22], protocol = "tcp" } +"192.168.0.170/32" = { priority = 1, action = "allow", ports = [22], protocol = "tcp", log = true, description = "Block SSH for single IP" } +# "192.168.0.170/24" = { priority = 2, action = "deny", ports = [22], protocol = "tcp", log = false, description = "Deny SSH from entire subnet" } +# "2001:0db8:85a3:0000:0000:8a2e:0370:7334" = { action = "deny", ports = [80], protocol = "tcp" } -[firewall.ipv6_rules] -"2001:0db8:85a3:0000:0000:8a2e:0370:7334" = { action = "deny", ports = [80], protocol = "tcp" } +[icmp_rules] +# Rules for ICMP traffic +"192.168.0.1/24" = { action = "deny", protocol = "icmp" } +"192.168.0.88/24" = { action = "allow", protocol = "icmp" } +"192.168.0.22/24" = { action = "deny", protocol = "icmp" } -[firewall.icmp_rules] -"192.168.0.1" = { action = "deny" } -"192.168.0.88" = { action = "allow" } -"192.168.0.22" = { action = "deny" } +[mac_rules] +# Rules for MAC address filtering +"00:0a:95:9d:68:16" = { action = "allow" } +"00:0a:95:9d:68:17" = { action = "deny" } + +[logging] +log_denied_packets = true +log_allowed_packets = false +log_format = "json" +log_file = "/var/log/firewall.log" + +[failsafe] +# Failsafe rule for unmatched traffic +action = "log" diff --git a/nflux/src/config.rs b/nflux/src/config.rs index abfef6b..8b13789 100644 --- a/nflux/src/config.rs +++ b/nflux/src/config.rs @@ -1,81 +1 @@ -use serde::Deserialize; -use std::collections::HashMap; -use std::env; -use std::fs; -// Enum to restrict `action` values -#[derive(Debug, Deserialize, PartialEq, Eq)] -#[serde(rename_all = "lowercase")] // Allows "deny" and "allow" as lowercase in TOML -pub enum Action { - Allow, - Deny, -} - -#[derive(Debug, Deserialize, PartialEq, Eq)] -#[serde(rename_all = "lowercase")] // Allows "deny" and "allow" as lowercase in TOML -pub enum Protocol { - Tcp, - Udp, -} - -#[derive(Deserialize, Debug)] -pub struct FirewallGlobalConfig { - pub icmp_enabled: bool, - pub interface_name: String, - pub log_level: String, - pub log_type: String, -} - -#[derive(Deserialize, Debug)] -pub struct FirewallIpv4Rules { - pub action: Action, - pub ports: Vec, - pub protocol: Protocol, -} - -#[derive(Deserialize, Debug)] -pub struct FirewallIpv6Rules { - pub action: Action, - pub ports: Vec, - pub protocol: Protocol, -} - -#[derive(Deserialize, Debug)] -pub struct IcmpRules { - pub action: Action, -} - -#[derive(Deserialize, Debug)] -pub struct FirewallConfig { - pub firewall: FirewallGlobalConfig, - pub ipv4_rules: HashMap, - pub ipv6_rules: HashMap, - pub icmp_rules: HashMap, -} - -#[derive(Deserialize, Debug)] -pub struct Config { - pub config: FirewallConfig, -} - -impl Config { - /// Load the configuration from a file, defaulting to `/etc/nflux/nflux.toml` if not specified - pub fn load() -> Self { - let config_file = env::var("NFLUX_CONFIG_FILE_PATH") - .unwrap_or_else(|_| "/etc/nflux/nflux.toml".to_string()); - - let config_content = match fs::read_to_string(&config_file) { - Ok(content) => content, - Err(e) => { - panic!("Failed to read configuration file {}: {}", config_file, e); - } - }; - - match toml::from_str(&config_content) { - Ok(config) => config, - Err(e) => { - panic!("Failed to parse configuration file {}: {}", config_file, e); - } - } - } -} diff --git a/nflux/src/lib.rs b/nflux/src/lib.rs index 97d7d2a..ade8fa3 100644 --- a/nflux/src/lib.rs +++ b/nflux/src/lib.rs @@ -5,10 +5,7 @@ mod utils; // Dependencies pub use config::Action; -pub use config::{ - Config, FirewallConfig, FirewallGlobalConfig, FirewallIpv4Rules, FirewallIpv6Rules, IcmpRules, - Protocol, -}; +pub use config::{FirewallConfig, FirewallGlobalConfig, IcmpRules, Protocol}; pub use core::set_mem_limit; /// RXH version. diff --git a/nflux/src/main.rs b/nflux/src/main.rs index 085347f..263abe6 100644 --- a/nflux/src/main.rs +++ b/nflux/src/main.rs @@ -2,86 +2,80 @@ mod config; mod core; mod logger; mod utils; + use crate::utils::{is_root_user, wait_for_shutdown}; use anyhow::Context; +use aya::maps::lpm_trie::Key; use aya::maps::perf::{AsyncPerfEventArrayBuffer, PerfBufferError}; -use aya::maps::{AsyncPerfEventArray, MapData}; +use aya::maps::{AsyncPerfEventArray, LpmTrie, Map, MapData}; use aya::programs::{Xdp, XdpFlags}; use aya::util::online_cpus; use aya::{include_bytes_aligned, Ebpf}; use bytes::BytesMut; +use config::{Action, FirewallConfig, Protocol, Rules}; use logger::setup_logger; -use nflux::{set_mem_limit, Action, Config, FirewallIpv4Rules, Protocol}; -use nflux_common::{ - convert_protocol, ConnectionEvent, Ipv4Rule, MAX_ALLOWED_IPV4, MAX_ALLOWED_PORTS, -}; +use nflux::set_mem_limit; +use nflux_common::{convert_protocol, ConnectionEvent, IpRule, LpmKeyIpv4, LpmKeyIpv6}; use std::collections::HashMap; -use std::net::Ipv4Addr; -use std::{env, ptr}; +use std::net::{Ipv4Addr, Ipv6Addr}; +use std::ptr; use tokio::task; -use tracing::{error, info}; +use tracing::{error, info, warn}; #[tokio::main] -async fn main() -> anyhow::Result<(), anyhow::Error> { +async fn main() -> anyhow::Result<()> { // Load configuration file - let config = Config::load(); + let config = FirewallConfig::load().context("Failed to load firewall configuration")?; // Enable logging - setup_logger( - &config.config.firewall.log_level, - &config.config.firewall.log_type, - ); + setup_logger(&config.firewall.log_level, &config.firewall.log_type); - // Check if user is root. + // Ensure the program is run as root if !is_root_user() { error!("This program must be run as root."); std::process::exit(1); } - // Mem limit + // Set memory limit set_mem_limit(); // Load eBPF program let mut bpf = Ebpf::load(include_bytes_aligned!(concat!(env!("OUT_DIR"), "/nflux")))?; - // If you want to print logs from eBPF program, uncomment the following lines - // if let Err(e) = aya_log::EbpfLogger::init(&mut bpf) { - // warn!("failed to initialize eBPF logger: {}", e); - // } - - // Populate EBPF map with app config - // populate_global_rules(&mut bpf, &app_config)?; - populate_ipv4_rules(&mut bpf, &config.config.ipv4_rules) - .context("Failed to populate IPv4 rules")?; + // Populate eBPF maps with configuration data + populate_ipv4_rules(&mut bpf, &config.ip_rules)?; + // populate_ipv6_rules(&mut bpf, &config.ip_rules)?; // Attach XDP program - // TODO: check if the interface you want to attach is valid (physical) - // XDP program can only be attached to physical interfaces let program: &mut Xdp = bpf.program_mut("nflux").unwrap().try_into()?; program.load()?; - program.attach(&config.config.firewall.interface_name.as_str(), XdpFlags::default()) - .context("failed to attach the XDP program with default flags - try changing XdpFlags::default() to XdpFlags::SKB_MODE")?; + program + .attach(&config.firewall.interface_names[0], XdpFlags::default()) + .context( + "Failed to attach XDP program. Ensure the interface is physical and not virtual.", + )?; - // Some basic info + // Log startup info info!("nflux started successfully!"); info!( - "Successfully attached XDP program to iface: {}", - config.config.firewall.interface_name + "XDP program attached to interface: {:?}", + config.firewall.interface_names[0] ); - info!("Checking incoming packets..."); - let mut events = AsyncPerfEventArray::try_from(bpf.take_map("CONNECTION_EVENTS").unwrap())?; + // Start processing events from the eBPF program + let mut events = AsyncPerfEventArray::try_from( + bpf.take_map("CONNECTION_EVENTS") + .context("Failed to find CONNECTION_EVENTS map")?, + )?; let cpus = online_cpus().map_err(|(_, error)| error)?; for cpu_id in cpus { let buf = events.open(cpu_id, None)?; - task::spawn(process_events(buf, cpu_id)); } // Wait for shutdown signal wait_for_shutdown().await?; - Ok(()) } @@ -105,7 +99,7 @@ async fn process_events( cpu_id, convert_protocol(event.protocol), event.dst_port, - Ipv4Addr::from(event.src_addr) + Ipv4Addr::from(event.src_addr), ); } Err(e) => error!("Failed to parse event on CPU {}: {}", cpu_id, e), @@ -114,78 +108,105 @@ async fn process_events( } } -// Helper function to convert Vec to [u32; N] -fn convert_ipv4_vec_to_array(vec: &Vec, max_len: usize) -> [u32; MAX_ALLOWED_IPV4] { - let mut array = [0; MAX_ALLOWED_IPV4]; - for (i, ip_str) in vec.iter().take(max_len).enumerate() { - if let Ok(ip) = ip_str.parse::() { - array[i] = u32::from(ip); - } +fn parse_connection_event(buf: &BytesMut) -> anyhow::Result { + if buf.len() >= std::mem::size_of::() { + let ptr = buf.as_ptr() as *const ConnectionEvent; + let event = unsafe { ptr::read_unaligned(ptr) }; + Ok(event) + } else { + Err(anyhow::anyhow!( + "Buffer size is too small for ConnectionEvent" + )) } - array } -// fn populate_global_rules(bpf: &mut Ebpf, global_rules: &GlobalFirewallRules) -> anyhow::Result<()> { -// let mut global_rules_map: Array<_, GlobalFirewallRules> = -// Array::try_from(bpf.map_mut("GLOBAL_FIREWALL_RULES").unwrap())?; -// global_rules_map.set(0, global_rules, 0)?; -// Ok(()) -// } - -fn populate_ipv4_rules( - bpf: &mut Ebpf, - ipv4_rules: &HashMap, // This comes from the `Config` -) -> anyhow::Result<()> { - let mut map: aya::maps::HashMap<_, u32, Ipv4Rule> = aya::maps::HashMap::try_from( +fn populate_ipv4_rules(bpf: &mut Ebpf, ip_rules: &HashMap) -> anyhow::Result<()> { + let mut ipv4_map: LpmTrie<&mut MapData, LpmKeyIpv4, IpRule> = LpmTrie::try_from( bpf.map_mut("IPV4_RULES") - .context("IPV4_RULES map not found")?, + .context("Failed to find IPV4_RULES map")?, )?; - for (ip_str, rule) in ipv4_rules { - // Parse IP string into u32 - let ip: u32 = ip_str.parse::()?.into(); + // Sort rules by priority + let mut sorted_rules: Vec<_> = ip_rules.iter().collect(); + sorted_rules.sort_by_key(|(_, rule)| rule.priority); - // Prepare ports array - let mut ports = [0u16; 16]; - for (i, &port) in rule.ports.iter().enumerate().take(16) { - ports[i] = port as u16; - } + for (cidr, rule) in sorted_rules { + println!("Loading rule: CIDR={}, {:?}", cidr, rule); + let (ip, prefix_len) = parse_cidr_v4(cidr)?; + let ip_rule = prepare_ip_rule(rule)?; - // Create Ipv4Rule struct - let ipv4_rule = Ipv4Rule { - action: if rule.action == Action::Allow { 1 } else { 0 }, - ports, - protocol: if rule.protocol == Protocol::Tcp { - 6 - } else { - 17 + let key = Key::new( + prefix_len, + LpmKeyIpv4 { + prefix_len, + ip: ip.into(), }, - }; - - // Insert into the map - map.insert(ip, ipv4_rule, 0)?; + ); + ipv4_map + .insert(&key, &ip_rule, 0) + .context("Failed to insert IPv4 rule")?; } Ok(()) } -fn convert_ports_vec_to_array(vec: &Vec, max_len: usize) -> [u32; MAX_ALLOWED_PORTS] { - let mut array = [0; MAX_ALLOWED_PORTS]; - for (i, &port) in vec.iter().take(max_len).enumerate() { - array[i] = port; +fn prepare_ip_rule(rule: &Rules) -> anyhow::Result { + let mut ports = [0u16; 16]; + for (i, &port) in rule.ports.iter().enumerate().take(16) { + ports[i] = port as u16; } - array + + Ok(IpRule { + action: match rule.action { + Action::Allow => 1, + Action::Deny => 0, + _ => { + warn!("Unsupported action: {:?}", rule.action); + return Err(anyhow::anyhow!("Unsupported action")); + } + }, + ports, + protocol: match rule.protocol { + Protocol::Tcp => 6, + Protocol::Udp => 17, + Protocol::Icmp => 1, + }, + priority: rule.priority, + }) } -fn parse_connection_event(buf: &BytesMut) -> anyhow::Result { - if buf.len() >= std::mem::size_of::() { - let ptr = buf.as_ptr() as *const ConnectionEvent; - // Safety: we've confirmed the buffer is large enough - let event = unsafe { ptr::read_unaligned(ptr) }; - Ok(event) - } else { - Err(anyhow::anyhow!( - "Buffer size is too small for ConnectionEvent" - )) +// fn populate_ipv6_rules(bpf: &mut Ebpf, ip_rules: &HashMap) -> anyhow::Result<()> { +// let mut ipv6_map: LpmTrie<&mut MapData, LpmKeyIpv6, IpRule> = LpmTrie::try_from( +// bpf.map_mut("IPV6_RULES").context("Failed to find IPV4_RULES map")?, +// )?; + +// for (cidr, rule) in ip_rules { +// let (ip, prefix_len) = parse_cidr_v6(cidr)?; +// let ip_rule = prepare_ip_rule(rule)?; + +// let key = Key::new(prefix_len, LpmKeyIpv6 { prefix_len, ip: ip.into() }); +// ipv6_map.insert(&key, &ip_rule, 0).context("Failed to insert IPv6 rule")?; +// } + +// Ok(()) +// } + +fn parse_cidr_v4(cidr: &str) -> anyhow::Result<(Ipv4Addr, u32)> { + let parts: Vec<&str> = cidr.split('/').collect(); + if parts.len() != 2 { + return Err(anyhow::anyhow!("Invalid CIDR format: {}", cidr)); } + let ip = parts[0].parse::()?; + let prefix_len = parts[1].parse::()?; + Ok((ip, prefix_len)) } + +// fn parse_cidr_v6(cidr: &str) -> anyhow::Result<(Ipv6Addr, u32)> { +// let parts: Vec<&str> = cidr.split('/').collect(); +// if parts.len() != 2 { +// return Err(anyhow::anyhow!("Invalid CIDR format: {}", cidr)); +// } +// let ip = parts[0].parse::()?; +// let prefix_len = parts[1].parse::()?; +// Ok((ip, prefix_len)) +// }