diff --git a/Cargo.toml b/Cargo.toml index 6826062..58b59f5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,3 +20,7 @@ strip = true [dev-dependencies] tempfile = "3.11.0" + +[lib] +name = "ratelimiter" +path = "src/ratelimiter.rs" diff --git a/src/main.rs b/src/main.rs index c569965..ed841b0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,6 +5,8 @@ // you may not use this file except in compliance with the License. // A copy of the License has been included in the root of the repository. +mod ratelimiter; + use std::collections::HashMap; use std::fs::File; use std::io::{Error, ErrorKind, Read, Result, Write}; @@ -13,10 +15,12 @@ use std::os::fd::AsRawFd; use std::path::Path; use std::str::FromStr; use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::{mpsc, Arc}; +use std::sync::{mpsc, Arc, Mutex}; use std::time::Instant; use walkdir::WalkDir; +use crate::ratelimiter::RateLimiter; + use borsh::BorshDeserialize; use borsh::BorshSerialize; @@ -27,26 +31,29 @@ Usage: fastsync recv Sender options: - Address (IP and port) for the sending side to bind to and - listen for receivers. This should be the address of a - Wireguard interface if you care about confidentiality. - E.g. '100.71.154.83:7999'. + Address (IP and port) for the sending side to bind to and + listen for receivers. This should be the address of a + Wireguard interface if you care about confidentiality. + E.g. '100.71.154.83:7999'. + + [--max-bandwidth-mbps ] Specify the maximum bandwidth to use over a 1 second sliding + window, in MB/s. If unspecified, there will be no limit. - Paths of files to send. Input file paths need to be relative. - This is a safety measure to make it harder to accidentally - overwrite files in /etc and the like on the receiving end. + Paths of files to send. Input file paths need to be relative. + This is a safety measure to make it harder to accidentally + overwrite files in /etc and the like on the receiving end. Receiver options: - The address (IP and port) that the sender is listening on. - E.g. '100.71.154.83:7999'. - - The number of TCP streams to open. For a value of 1, Fastsync - behaves very similar to 'netcat'. With higher values, - Fastsync leverages the fact that file chunks don't need to - arrive in order to avoid the head-of-line blocking of a - single connection. You should experiment to find the best - value, going from 1 to 4 is usually helpful, going from 16 - to 32 is probably overkill. + The address (IP and port) that the sender is listening on. + E.g. '100.71.154.83:7999'. + + The number of TCP streams to open. For a value of 1, Fastsync + behaves very similar to 'netcat'. With higher values, + Fastsync leverages the fact that file chunks don't need to + arrive in order to avoid the head-of-line blocking of a + single connection. You should experiment to find the best + value, going from 1 to 4 is usually helpful, going from 16 + to 32 is probably overkill. "; const WIRE_PROTO_VERSION: u16 = 1; @@ -133,12 +140,25 @@ fn main() { match args.first().map(|s| &s[..]) { Some("send") if args.len() >= 3 => { let addr = &args[1]; - let fnames = &args[2..]; + let max_bandwidth = match args[2].as_str() { + "--max-bandwidth-mbps" => Some( + args[3] + .parse::() + .expect("Invalid number for --max-bandwidth-mbps"), + ), + _ => None, + }; + let fnames = if max_bandwidth.is_some() { + &args[4..] + } else { + &args[2..] + }; main_send( SocketAddr::from_str(addr).expect("Invalid send address"), fnames, WIRE_PROTO_VERSION, events_tx, + max_bandwidth, ) .expect("Failed to send."); } @@ -179,7 +199,7 @@ struct SendState { enum SendResult { Done, - Progress, + Progress { bytes_sent: u64 }, } /// Metadata about a chunk of data that follows. @@ -234,6 +254,7 @@ impl SendState { let mut off = offset as i64; let out_fd = out.as_raw_fd(); let in_fd = self.in_file.as_raw_fd(); + let mut total_written: u64 = 0; while off < end { let count = (end - off) as usize; // Note, sendfile advances the offset by the number of bytes written @@ -242,9 +263,12 @@ impl SendState { if n_written < 0 { return Err(Error::last_os_error()); } + total_written += n_written as u64; } - Ok(SendResult::Progress) + Ok(SendResult::Progress { + bytes_sent: total_written, + }) } } @@ -275,6 +299,7 @@ fn main_send( fnames: &[String], protocol_version: u16, sender_events: std::sync::mpsc::Sender, + max_bandwidth_mbps: Option, ) -> Result<()> { let mut plan = TransferPlan { proto_version: protocol_version, @@ -314,6 +339,13 @@ fn main_send( )) .expect("Listener should not exit before the sender."); + let limiter_mutex = Arc::new(Mutex::new(Option::::None)); + + if let Some(mbps) = max_bandwidth_mbps { + let ratelimiter = RateLimiter::new(mbps, MAX_CHUNK_LEN, Instant::now()); + _ = limiter_mutex.lock().unwrap().insert(ratelimiter); + } + loop { let (mut stream, addr) = listener.accept()?; println!("Accepted connection from {addr}."); @@ -337,15 +369,34 @@ fn main_send( } let state_clone = state_arc.clone(); + + let limiter_mutex_2 = limiter_mutex.clone(); let push_thread = std::thread::spawn(move || { let start_time = Instant::now(); // All the threads iterate through all the files one by one, so all // the threads collaborate on sending the first one, then the second // one, etc. + 'files: for file in state_clone.iter() { 'chunks: loop { + let mut limiter_mutex = limiter_mutex_2.lock().unwrap(); + let mut opt_ratelimiter = limiter_mutex.as_mut(); + if let Some(ref mut ratelimiter) = opt_ratelimiter { + let to_wait = + ratelimiter.time_until_bytes_available(Instant::now(), MAX_CHUNK_LEN); + // if to_wait is None, we've requested to send more than the bucket's max + // capacity, which is a programming error. Crash the program. + std::thread::sleep(to_wait.unwrap()); + } match file.send_one(start_time, &mut stream) { - Ok(SendResult::Progress) => continue 'chunks, + Ok(SendResult::Progress { + bytes_sent: bytes_written, + }) => { + if let Some(ref mut ratelimiter) = opt_ratelimiter { + ratelimiter.consume_bytes(Instant::now(), bytes_written); + } + continue 'chunks; + } Ok(SendResult::Done) => continue 'files, Err(err) => panic!("Failed to send: {err}"), } @@ -601,6 +652,7 @@ mod tests { &["a-file".into()], 1, events_tx, + None, ) .unwrap(); }); @@ -627,6 +679,7 @@ mod tests { &["a-file".into()], 2, events_tx, + None, ) .unwrap(); }); diff --git a/src/ratelimiter.rs b/src/ratelimiter.rs new file mode 100644 index 0000000..98b118e --- /dev/null +++ b/src/ratelimiter.rs @@ -0,0 +1,145 @@ +use std::time::{Duration, Instant}; + +#[derive(Debug)] +pub struct RateLimiter { + capacity_bytes: u64, + available_bytes: u64, + bytes_per_second: u64, + last_update: Instant, +} + +impl RateLimiter { + pub fn new(mbps_target: u64, capacity_bytes: u64, now: Instant) -> Self { + let bps_target = mbps_target * 1_000_000; + RateLimiter { + capacity_bytes, + available_bytes: capacity_bytes, + bytes_per_second: bps_target, + last_update: now, + } + } + + pub fn bytes_available(&self, now: Instant) -> u64 { + let elapsed = now - self.last_update; + let new_bytes = elapsed.as_secs_f32() * self.bytes_per_second as f32; + std::cmp::min(self.available_bytes + new_bytes as u64, self.capacity_bytes) + } + + pub fn consume_bytes(&mut self, now: Instant, amount: u64) { + let elapsed = now - self.last_update; + let new_bytes = (elapsed.as_secs_f32() * self.bytes_per_second as f32) as u64; + self.available_bytes += new_bytes; + self.available_bytes = std::cmp::min(self.available_bytes, self.capacity_bytes); + assert!(self.available_bytes >= amount); + self.available_bytes -= amount; + self.last_update = now; + } + + pub fn time_until_bytes_available(&self, now: Instant, amount: u64) -> Option { + if amount > self.capacity_bytes { + return None; + } + let elapsed = now - self.last_update; + let new_bytes = (elapsed.as_secs_f32() * self.bytes_per_second as f32) as u64; + let total_bytes = self.available_bytes + new_bytes; + if self.available_bytes + new_bytes > amount { + return Some(Duration::from_secs(0)); + } + + let needed = amount - total_bytes; + Some(Duration::from_secs_f32( + needed as f32 / self.bytes_per_second as f32, + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_initial_state() { + let start = Instant::now(); + let rl = RateLimiter::new(10, 10_000_000, start); + + assert_eq!(rl.bytes_per_second, 10_000_000); + assert_eq!(rl.bytes_available(start), rl.bytes_per_second); + } + + #[test] + fn test_bytes_available_after_one_second() { + let start = Instant::now(); + let rl = RateLimiter::new(10, 10_000_000, start); + + let now = start + Duration::from_secs(1); + assert_eq!(rl.bytes_available(now), 10_000_000); + } + + #[test] + fn test_consume_bytes() { + let start = Instant::now(); + let mut rl = RateLimiter::new(10, 10_000_000, start); + + let now = start + Duration::from_secs(1); + assert_eq!(rl.bytes_available(now), 10_000_000); + rl.consume_bytes(now, 4_000_000); + assert_eq!(rl.available_bytes, 6_000_000); + } + + #[test] + fn test_bytes_available_capped_at_max() { + let start = Instant::now(); + let mut rl = RateLimiter::new(10, 10_000_000, start); + + let now = start + Duration::from_secs(1); + rl.consume_bytes(now, 5_000_000); + + let now = now + Duration::from_millis(500); // 0.5 seconds later + assert_eq!(rl.bytes_available(now), 10_000_000); // Should be capped at max + + let now = now + Duration::from_millis(500); // 0.5 seconds later + assert_eq!(rl.bytes_available(now), 10_000_000); // Should be capped at max + } + + #[test] + fn test_time_until_bytes_available() { + let start = Instant::now(); + let mut rl = RateLimiter::new(10, 10_000_000, start); + + let now = start + Duration::from_secs(1); + rl.consume_bytes(now, 9_000_000); + assert_eq!(rl.available_bytes, 1_000_000); + + let wait_time = rl.time_until_bytes_available(now, 9_000_000).unwrap(); + // at 10MB/s, 800ms for 800KB + assert!(wait_time > Duration::from_millis(799) && wait_time < Duration::from_millis(801)); + } + + #[test] + fn test_immediate_availability() { + let start = Instant::now(); + let mut rl = RateLimiter::new(10, 10_000_000, start); + + let now = start + Duration::from_secs(1); + rl.consume_bytes(now, 9_000_000); + + assert_eq!( + rl.time_until_bytes_available(now, 1_000_000).unwrap(), + Duration::from_secs(0) + ); + } + + #[test] + fn test_wait_time_beyond_bucket_capacity() { + let start = Instant::now(); + let mut rl = RateLimiter::new(10, 10_000_000, start); + + let now = start + Duration::from_secs(1); + rl.consume_bytes(now, 9_000_000); + + // this is not true, there will never be 20M available in the bucket. + // not sure if this case should throw when asking for > bps + let wait_time = rl.time_until_bytes_available(now, 20_000_000); + assert!(wait_time.is_none()); + } +}