diff --git a/Cargo.lock b/Cargo.lock index 506173a..e943159 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,12 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + [[package]] name = "bitflags" version = "2.6.0" @@ -72,6 +78,7 @@ version = "1.0.0" dependencies = [ "borsh", "libc", + "parking_lot", "tempfile", "walkdir", ] @@ -104,6 +111,16 @@ version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "memchr" version = "2.7.1" @@ -116,6 +133,29 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +[[package]] +name = "parking_lot" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets", +] + [[package]] name = "proc-macro-crate" version = "3.1.0" @@ -166,6 +206,15 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "redox_syscall" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" +dependencies = [ + "bitflags", +] + [[package]] name = "rustix" version = "0.38.34" @@ -188,6 +237,18 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "smallvec" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" + [[package]] name = "syn" version = "2.0.49" diff --git a/Cargo.toml b/Cargo.toml index 58b59f5..af1c76f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ publish = false libc = "0.2.153" borsh = { version = "1.3.1", features = ["derive"] } walkdir = "2.5.0" +parking_lot = "0.12.3" [profile.dev] panic = "abort" diff --git a/src/main.rs b/src/main.rs index ed841b0..dc7dd87 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,9 +12,8 @@ use std::fs::File; use std::io::{Error, ErrorKind, Read, Result, Write}; use std::net::{SocketAddr, TcpStream}; use std::os::fd::AsRawFd; -use std::path::Path; +use std::path::{Path, PathBuf}; use std::str::FromStr; -use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{mpsc, Arc, Mutex}; use std::time::Instant; use walkdir::WalkDir; @@ -56,7 +55,7 @@ Receiver options: to 32 is probably overkill. "; -const WIRE_PROTO_VERSION: u16 = 1; +const WIRE_PROTO_VERSION: u16 = 2; const MAX_CHUNK_LEN: u64 = 4096 * 64; /// Metadata about all the files we want to transfer. @@ -104,7 +103,7 @@ impl TransferPlan { fn assert_paths_relative(&self) { for file in &self.files { assert!( - !file.name.starts_with("/"), + !file.name.starts_with('/'), "Transferring files with an absolute path name is not allowed.", ); } @@ -113,11 +112,11 @@ impl TransferPlan { /// The index of a file in the transfer plan. #[derive(BorshDeserialize, BorshSerialize, Copy, Clone, Debug, Eq, Hash, PartialEq)] -struct FileId(u16); +struct FileId(u32); impl FileId { fn from_usize(i: usize) -> FileId { - assert!(i < u16::MAX as usize, "Can transfer at most 2^16 files."); + assert!(i < u32::MAX as usize, "Can transfer at most 2^32 files."); FileId(i as _) } } @@ -190,21 +189,27 @@ fn print_progress(offset: u64, len: u64, start_time: Instant) { ); } +enum SendStateInner { + Pending { fname: PathBuf }, + InProgress { file: File, offset: u64 }, + Done, +} + struct SendState { id: FileId, len: u64, - offset: AtomicU64, - in_file: File, + state: parking_lot::Mutex, } enum SendResult { Done, + FileVanished, Progress { bytes_sent: u64 }, } /// Metadata about a chunk of data that follows. /// -/// The Borsh-generated representation of this is zero-overhead (14 bytes). +/// The Borsh-generated representation of this is zero-overhead (16 bytes). #[derive(BorshDeserialize, BorshSerialize, Debug)] struct ChunkHeader { /// Which file is the chunk from? @@ -218,8 +223,8 @@ struct ChunkHeader { } impl ChunkHeader { - fn to_bytes(&self) -> [u8; 14] { - let mut buffer = [0_u8; 14]; + fn to_bytes(&self) -> [u8; 16] { + let mut buffer = [0_u8; 16]; let mut cursor = std::io::Cursor::new(&mut buffer[..]); self.serialize(&mut cursor) .expect("Writing to memory never fails."); @@ -229,12 +234,49 @@ impl ChunkHeader { impl SendState { pub fn send_one(&self, start_time: Instant, out: &mut TcpStream) -> Result { - let offset = self.offset.fetch_add(MAX_CHUNK_LEN, Ordering::SeqCst); + // By deferring the opening of the file descriptor to this point, + // we effectively limit the amount of open files to the amount of send threads. + // However, this now introduces the possibility of files getting deleted between + // getting listed and their turn for transfer. + // A vanishing file is expected, it is not a transfer-terminating event. + let mut state = self.state.lock(); + let (offset, in_fd) = match *state { + SendStateInner::Pending { ref fname } => { + let res = match std::fs::File::open(fname) { + Ok(f) => f, + Err(e) if e.kind() == std::io::ErrorKind::NotFound => { + return Ok(SendResult::FileVanished) + } + Err(e) => return Err(e), + }; + let fd = res.as_raw_fd(); + *state = SendStateInner::InProgress { + file: res, + offset: 0, + }; + (0, fd) + } + SendStateInner::Done => { + return Ok(SendResult::Done); + } + SendStateInner::InProgress { + ref file, + ref mut offset, + } => { + *offset += MAX_CHUNK_LEN; + (*offset, file.as_raw_fd()) + } + }; + let end = self.len.min(offset + MAX_CHUNK_LEN); if offset >= self.len || offset >= end { + *state = SendStateInner::Done; return Ok(SendResult::Done); } + // Drop the lock while sending -- so that multiple threads can + // send data from the same file at once + std::mem::drop(state); print_progress(offset, self.len, start_time); @@ -252,9 +294,8 @@ impl SendState { let end = end as i64; let mut off = offset as i64; - let out_fd = out.as_raw_fd(); - let in_fd = self.in_file.as_raw_fd(); let mut total_written: u64 = 0; + let out_fd = out.as_raw_fd(); while off < end { let count = (end - off) as usize; // Note, sendfile advances the offset by the number of bytes written @@ -308,8 +349,7 @@ fn main_send( let mut send_states = Vec::new(); for (i, fname) in all_filenames_from_path_names(fnames)?.iter().enumerate() { - let file = std::fs::File::open(&fname)?; - let metadata = file.metadata()?; + let metadata = std::fs::metadata(fname)?; let file_plan = FilePlan { name: fname.clone(), len: metadata.len(), @@ -317,8 +357,9 @@ fn main_send( let state = SendState { id: FileId::from_usize(i), len: metadata.len(), - offset: AtomicU64::new(0), - in_file: file, + state: parking_lot::Mutex::new(SendStateInner::Pending { + fname: fname.into(), + }), }; plan.files.push(file_plan); send_states.push(state); @@ -363,7 +404,7 @@ fn main_send( // Stop the listener, don't send anything over our new connection. let is_done = state_arc .iter() - .all(|f| f.offset.load(Ordering::SeqCst) >= f.len); + .all(|f| matches!(*f.state.lock(), SendStateInner::Done)); if is_done { break; } @@ -389,6 +430,9 @@ fn main_send( std::thread::sleep(to_wait.unwrap()); } match file.send_one(start_time, &mut stream) { + Ok(SendResult::FileVanished) => { + println!("File {:?} vanished", file.id); + } Ok(SendResult::Progress { bytes_sent: bytes_written, }) => { @@ -427,8 +471,6 @@ struct Chunk { struct FileReceiver { fname: String, - /// The file we’re writing to, if we have started writing. - /// /// We don’t open the file immediately so we don’t create a zero-sized file /// when a transfer fails. We only open the file after we have at least some /// data for it. @@ -488,7 +530,11 @@ impl FileReceiver { self.offset += chunk.data.len() as u64; } - self.out_file = Some(out_file); + if self.offset < self.total_len { + self.out_file = Some(out_file); + // Only keep the file open as long as there is more to write + } + Ok(()) } } @@ -571,7 +617,7 @@ fn main_recv( // Read a chunk header. If we hit EOF, that is not an error, it // means that the sender has nothing more to send so we can just // exit here. - let mut buf = [0u8; 14]; + let mut buf = [0u8; 16]; match stream.read_exact(&mut buf) { Ok(..) => {} Err(err) if err.kind() == ErrorKind::UnexpectedEof => break, @@ -637,10 +683,12 @@ fn main_recv( #[cfg(test)] mod tests { use super::*; + use std::env; use std::{ net::{IpAddr, Ipv4Addr}, thread, }; + use tempfile::TempDir; #[test] fn test_accepts_valid_protocol() { @@ -717,4 +765,79 @@ mod tests { ["0", "a/1", "a/b/2"].map(|f| base_path.join(f).to_str().unwrap().to_owned()) ); } + + #[test] + fn test_sends_large_file() { + let (events_tx, events_rx) = std::sync::mpsc::channel::(); + env::set_current_dir("/tmp/").unwrap(); + let cwd = env::current_dir().unwrap(); + thread::spawn(|| { + let td = TempDir::new_in(".").unwrap(); + let tmp_path = td.path().strip_prefix(cwd).unwrap(); + let path = tmp_path.join("large"); + let fnames = &[path.clone().into_os_string().into_string().unwrap()]; + + { + let mut f = std::fs::File::create(path).unwrap(); + f.write_all(&vec![0u8; MAX_CHUNK_LEN as usize * 100]) + .unwrap(); + } + + main_send( + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0), + fnames, + 1, + events_tx, + None, + ) + .unwrap(); + }); + match events_rx.recv().unwrap() { + SenderEvent::Listening(port) => { + main_recv( + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port), + "1", + WriteMode::Force, + 1, + ) + .unwrap(); + } + } + } + #[test] + fn test_sends_20_thousand_files() { + let (events_tx, events_rx) = std::sync::mpsc::channel::(); + env::set_current_dir("/tmp/").unwrap(); + let cwd = env::current_dir().unwrap(); + thread::spawn(|| { + let td = TempDir::new_in(".").unwrap(); + let tmp_path = td.path().strip_prefix(cwd).unwrap(); + let mut fnames = Vec::new(); + for i in 0..20_000 { + let path = tmp_path.join(i.to_string()); + fnames.push(path.clone().into_os_string().into_string().unwrap()); + let mut f = std::fs::File::create(path).unwrap(); + f.write(&[1, 2, 3]).unwrap(); + } + main_send( + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0), + &fnames, + 1, + events_tx, + None, + ) + .unwrap(); + }); + match events_rx.recv().unwrap() { + SenderEvent::Listening(port) => { + main_recv( + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port), + "1", + WriteMode::Force, + 1, + ) + .unwrap(); + } + } + } }