diff --git a/examples/get_metadata/Cargo.toml b/examples/get_metadata/Cargo.toml
index 7c820e45a..bcd51bf63 100644
--- a/examples/get_metadata/Cargo.toml
+++ b/examples/get_metadata/Cargo.toml
@@ -18,6 +18,7 @@ dht = { path = "../../packages/dht" }
 handshake = { path = "../../packages/handshake" }
 peer = { path = "../../packages/peer" }
 select = { path = "../../packages/select" }
+metainfo = { path = "../../packages/metainfo" }

 clap = "4"
 hex = "0"
diff --git a/examples/get_metadata/src/main.rs b/examples/get_metadata/src/main.rs
index 6c3e762ee..10b963e5c 100644
--- a/examples/get_metadata/src/main.rs
+++ b/examples/get_metadata/src/main.rs
@@ -2,7 +2,7 @@ use std::fmt::Debug;
 use std::fs::File;
 use std::io::Write as _;
 use std::net::SocketAddr;
-use std::sync::Arc;
+use std::sync::{Arc, Once};
 use std::time::Duration;

 use clap::{Arg, ArgMatches, Command};
@@ -15,6 +15,7 @@ use handshake::{
     DiscoveryInfo, Extension, Extensions, HandshakerBuilder, HandshakerConfig, InfoHash, InitiateMessage, PeerId, Protocol,
 };
 use hex::FromHex;
+use metainfo::Metainfo;
 use peer::messages::builders::ExtendedMessageBuilder;
 use peer::messages::{BitsExtensionMessage, PeerExtensionProtocolMessage, PeerWireProtocolMessage};
 use peer::protocols::{NullProtocol, PeerExtensionProtocol, PeerWireProtocol};
@@ -23,7 +24,11 @@ use peer::{
 };
 use select::discovery::{IDiscoveryMessage, ODiscoveryMessage, UtMetadataModule};
 use select::{ControlMessage, IExtendedMessage, IUberMessage, OUberMessage, UberModuleBuilder};
+use tokio::signal;
 use tokio_util::codec::Framed;
+use tracing::level_filters::LevelFilter;
+
+pub static INIT: Once = Once::new();

 // Legacy Handshaker, when bip_dht is migrated, it will accept S directly
 struct LegacyHandshaker<S> {
@@ -103,9 +108,39 @@ fn extract_arguments(matches: &ArgMatches) -> (String, String) {
     (hash, output)
 }

+pub fn tracing_stdout_init(filter: LevelFilter) {
+    let builder = tracing_subscriber::fmt()
+        .with_max_level(filter)
+        .with_ansi(true)
+        .with_writer(std::io::stdout);
+
+    builder.pretty().with_file(true).init();
+
+    tracing::info!("Logging initialized");
+}
+
+async fn ctrl_c() {
+    signal::ctrl_c().await.expect("failed to listen for event");
+    tracing::warn!("Ctrl-C received, shutting down...");
+}
+
+enum SendUber {
+    Finished(Result<(), select::error::Error>),
+    Interrupted,
+}
+
+enum MainDht {
+    Finished(Box<Metainfo>),
+    Interrupted,
+}
+
 #[allow(clippy::too_many_lines)]
 #[tokio::main]
 async fn main() {
+    INIT.call_once(|| {
+        tracing_stdout_init(LevelFilter::TRACE);
+    });
+
     // Parse command-line arguments
     let matches = parse_arguments();
     let (hash, output) = extract_arguments(&matches);
@@ -170,12 +205,33 @@ async fn main() {
         .into_parts();

     // Tell the uber module we want to download metainfo for the given hash
-    uber_send
+    let send_to_uber = uber_send
         .send(IUberMessage::Discovery(Box::new(IDiscoveryMessage::DownloadMetainfo(
             info_hash,
         ))))
-        .await
-        .expect("it should send the instruction");
+        .boxed();
+
+    // Await either the sending to uber or the Ctrl-C signal
+    let send_to_uber = tokio::select! {
+        res = send_to_uber => SendUber::Finished(res),
+        () = ctrl_c() => SendUber::Interrupted,
+    };
+
+    let () = match send_to_uber {
+        SendUber::Finished(Ok(())) => (),
+
+        SendUber::Finished(Err(e)) => {
+            tracing::warn!("send to uber failed with error: {e}");
+            tasks.shutdown().await;
+            return;
+        }
+
+        SendUber::Interrupted => {
+            tracing::warn!("setup was canceled...");
+            tasks.shutdown().await;
+            return;
+        }
+    };

     let timer = futures::stream::unfold(tokio::time::interval(Duration::from_millis(100)), |mut interval| async move {
         interval.tick().await;
@@ -192,11 +248,11 @@ async fn main() {
         let message = if let Either::Left(message) = item {
             match message {
                 Ok(PeerManagerOutputMessage::PeerAdded(info)) => {
-                    println!("Connected To Peer: {info:?}");
+                    tracing::info!("Connected To Peer: {info:?}");
                     IUberMessage::Control(Box::new(ControlMessage::PeerConnected(info)))
                 }
                 Ok(PeerManagerOutputMessage::PeerRemoved(info)) => {
-                    println!("We Removed Peer {info:?} From The Peer Manager");
+                    tracing::info!("We Removed Peer {info:?} From The Peer Manager");
                     IUberMessage::Control(Box::new(ControlMessage::PeerDisconnected(info)))
                 }
                 Ok(PeerManagerOutputMessage::SentMessage(_, _)) => todo!(),
@@ -216,7 +272,7 @@ async fn main() {
                     _ => unimplemented!(),
                 },
                 Ok(PeerManagerOutputMessage::PeerDisconnect(info)) => {
-                    println!("Peer {info:?} Disconnected From Us");
+                    tracing::info!("Peer {info:?} Disconnected From Us");
                     IUberMessage::Control(Box::new(ControlMessage::PeerDisconnected(info)))
                 }
                 Err(e) => {
@@ -227,7 +283,7 @@ async fn main() {
                         | PeerManagerOutputError::PeerDisconnectedAndMissing(info) => info,
                     };

-                    println!("Peer {info:?} Disconnected With Error: {e:?}");
+                    tracing::info!("Peer {info:?} Disconnected With Error: {e:?}");
                     IUberMessage::Control(Box::new(ControlMessage::PeerDisconnected(info)))
                 }
             }
@@ -241,25 +297,45 @@ async fn main() {
     // Setup the dht which will be the only peer discovery service we use in this example
     let legacy_handshaker = LegacyHandshaker::new(handshaker_send);
-    let dht = DhtBuilder::with_router(Router::uTorrent)
-        .set_read_only(false)
-        .start_mainline(legacy_handshaker)
-        .await
-        .expect("it should start the dht mainline");
-    println!("Bootstrapping Dht...");
-    while let Some(message) = dht.events().await.next().await {
-        if let DhtEvent::BootstrapCompleted = message {
-            break;
+    let main_dht = async move {
+        let dht = DhtBuilder::with_router(Router::uTorrent)
+            .set_read_only(false)
+            .start_mainline(legacy_handshaker)
+            .await
+            .expect("it should start the dht mainline");
+
+        tracing::info!("Bootstrapping Dht...");
+        while let Some(message) = dht.events().await.next().await {
+            if let DhtEvent::BootstrapCompleted = message {
+                break;
+            }
+        }
+        tracing::info!("Bootstrap Complete...");
+
+        dht.search(info_hash, true).await;
+
+        loop {
+            if let Some(Ok(OUberMessage::Discovery(ODiscoveryMessage::DownloadedMetainfo(metainfo)))) = uber_recv.next().await {
+                break metainfo;
+            }
         }
     }
-    println!("Bootstrap Complete...");
+    .boxed();
+
+    // Await either the DHT search or the Ctrl-C signal
+    let main_dht = tokio::select! {
+        res = main_dht => MainDht::Finished(Box::new(res)),
+        () = ctrl_c() => MainDht::Interrupted,
+    };

-    dht.search(info_hash, true).await;
+    let metainfo = match main_dht {
+        MainDht::Finished(metainfo) => metainfo,

-    let metainfo = loop {
-        if let Some(Ok(OUberMessage::Discovery(ODiscoveryMessage::DownloadedMetainfo(metainfo)))) = uber_recv.next().await {
-            break metainfo;
+        MainDht::Interrupted => {
+            tracing::warn!("setup was canceled...");
+            tasks.shutdown().await;
+            return;
         }
     };
diff --git a/examples/simple_torrent/src/main.rs b/examples/simple_torrent/src/main.rs
index 87295d9a0..eb5a5b0a1 100644
--- a/examples/simple_torrent/src/main.rs
+++ b/examples/simple_torrent/src/main.rs
@@ -79,7 +79,7 @@ pub fn tracing_stdout_init(filter: LevelFilter) {

 async fn ctrl_c() {
     signal::ctrl_c().await.expect("failed to listen for event");
-    println!("Ctrl-C received, shutting down...");
+    tracing::warn!("Ctrl-C received, shutting down...");
 }

 #[tokio::main]
diff --git a/examples/simple_torrent/src/main_old.rs b/examples/simple_torrent/src/main_old.rs
deleted file mode 100644
index 8a4e23da4..000000000
--- a/examples/simple_torrent/src/main_old.rs
+++ /dev/null
@@ -1,596 +0,0 @@
-use std::cell::RefCell;
-use std::cmp;
-use std::collections::HashMap;
-use std::fs::File;
-use std::io::Read;
-use std::rc::Rc;
-
-use clap::clap_app;
-use disk::fs::NativeFileSystem;
-use disk::fs_cache::FileHandleCache;
-use disk::{Block, BlockMetadata, BlockMut, DiskManagerBuilder, IDiskMessage, ODiskMessage};
-use futures::future::{Either, Loop};
-use futures::sync::mpsc;
-use futures::{future, stream, Future, Sink, Stream};
-use handshake::transports::TcpTransport;
-//use bip_dht::{DhtBuilder, Handshaker, Router};
-use handshake::{Extensions, HandshakerBuilder, HandshakerConfig, InitiateMessage, PeerId, Protocol};
-use metainfo::{Info, Metainfo};
-use peer::messages::{BitFieldMessage, HaveMessage, PeerWireProtocolMessage, PieceMessage, RequestMessage};
-use peer::protocols::{NullProtocol, PeerWireProtocol};
-use peer::{IPeerManagerMessage, OPeerManagerMessage, PeerInfo, PeerManagerBuilder, PeerProtocolCodec};
-use tokio_core::reactor::Core;
-
-/*
-    Things this example doesn't do, because of the lack of bip_select:
-      * Logic for piece selection is not abstracted (and is pretty bad)
-      * We will unconditionally upload pieces to a peer (regardless whether or not they were choked)
-      * We don't add an info hash filter to bip_handshake after we have as many peers as we need/want
-      * We don't do any banning of malicious peers
-
-    Things the example doesn't do, unrelated to bip_select:
-      * Matching peers up to disk requests isn't as good as it could be
-      * doesn't use a shared BytesMut for servicing piece requests
-      * Good logging
-*/
-
-/*
-// Legacy Handshaker, when bip_dht is migrated, it will accept S directly
-struct LegacyHandshaker<S> {
-    port: u16,
-    id: PeerId,
-    sender: Wait<S>,
-}
-
-impl<S> LegacyHandshaker<S> where S: DiscoveryInfo + Sink {
-    pub fn new(sink: S) -> LegacyHandshaker<S> {
-        LegacyHandshaker { port: sink.port(), id: sink.peer_id(), sender: sink.wait() }
-    }
-}
-
-impl<S> Handshaker for LegacyHandshaker<S> where S: Sink + Send, S::SinkError: Debug {
-    type MetadataEnvelope = ();
-
-    fn id(&self) -> PeerId { self.id }
-
-    fn port(&self) -> u16 { self.port }
-
-    fn connect(&mut self, _expected: Option<PeerId>, hash: InfoHash, addr: SocketAddr) {
-        self.sender.send(InitiateMessage::new(Protocol::BitTorrent, hash, addr));
-    }
-
-    fn metadata(&mut self, _data: ()) { () }
-}
-*/
-
-// How many requests can be in flight at once.
-const MAX_PENDING_BLOCKS: usize = 50;
-
-// Some enum to store our selection state updates
-#[allow(dead_code)]
-#[derive(Debug)]
-enum SelectState {
-    Choke(PeerInfo),
-    UnChoke(PeerInfo),
-    Interested(PeerInfo),
-    UnInterested(PeerInfo),
-    Have(PeerInfo, HaveMessage),
-    BitField(PeerInfo, BitFieldMessage),
-    NewPeer(PeerInfo),
-    RemovedPeer(PeerInfo),
-    BlockProcessed,
-    GoodPiece(u64),
-    BadPiece(u64),
-    TorrentSynced,
-    TorrentAdded,
-}
-
-#[allow(clippy::too_many_lines)]
-fn main() {
-    // Command line argument parsing
-    let matches = clap_app!(myapp =>
-        (version: "1.0")
-        (author: "Andrew ")
-        (about: "Simple torrent downloading")
-        (@arg file: -f +required +takes_value "Location of the torrent file")
-        (@arg dir: -d +takes_value "Download directory to use")
-        (@arg peer: -p +takes_value "Single peer to connect to of the form addr:port")
-    )
-    .get_matches();
-    let file = matches.value_of("file").unwrap();
-    let dir = matches.value_of("dir").unwrap();
-    let peer_addr = matches.value_of("peer").unwrap().parse().unwrap();
-
-    // Load in our torrent file
-    let mut metainfo_bytes = Vec::new();
-    File::open(file).unwrap().read_to_end(&mut metainfo_bytes).unwrap();
-
-    // Parse out our torrent file
-    let metainfo = Metainfo::from_bytes(metainfo_bytes).unwrap();
-    let info_hash = metainfo.info().info_hash();
-
-    // Create our main "core" event loop
-    let mut core = Core::new().unwrap();
-
-    // Create a disk manager to handle storing/loading blocks (we add in a file handle cache
-    // to avoid anti virus causing slow file opens/closes, will cache up to 100 file handles)
-    let (disk_manager_send, disk_manager_recv) = DiskManagerBuilder::new()
-        // Reducing our sink and stream capacities allow us to constrain memory usage
-        // (though for spiky downloads, this could effectively throttle us, which is ok too.)
-        .with_sink_buffer_capacity(1)
-        .with_stream_buffer_capacity(0)
-        .build(FileHandleCache::new(NativeFileSystem::with_directory(dir), 100))
-        .into_parts();
-
-    // Create a handshaker that can initiate connections with peers
-    let (handshaker_send, handshaker_recv) = HandshakerBuilder::new()
-        .with_peer_id(PeerId::from_hash("-BI0000-000000000000".as_bytes()).unwrap())
-        // We would ideally add a filter to the handshaker to block
-        // peers when we have enough of them for a given hash, but
-        // since this is a minimal example, we will rely on peer
-        // manager backpressure (handshaker -> peer manager will
-        // block when we reach our max peers). Setting these to low
-        // values so we don't have more than 2 unused tcp connections.
-        .with_config(HandshakerConfig::default().with_wait_buffer_size(0).with_done_buffer_size(0))
-        .build::<TcpTransport>(TcpTransport, &core.handle()) // Will handshake over TCP (could swap this for UTP in the future)
-        .unwrap()
-        .into_parts();
-
-    // Create a peer manager that will hold our peers and heartbeat/send messages to them
-    let (peer_manager_send, peer_manager_recv) = PeerManagerBuilder::new()
-        // Similar to the disk manager sink and stream capacities, we can constrain those
-        // for the peer manager as well.
-        .with_sink_buffer_capacity(0)
-        .with_stream_buffer_capacity(0)
-        .build(core.handle())
-        .into_parts();
-
-    // Hook up a future that feeds incoming (handshaken) peers over to the peer manager
-    let map_peer_manager_send = peer_manager_send.clone().sink_map_err(|_| ());
-    core.handle().spawn(
-        handshaker_recv
-            .map_err(|()| ())
-            .map(|complete_msg| {
-                // Our handshaker finished handshaking some peer, get
-                // the peer info as well as the peer itself (socket)
-                let (_, _, hash, pid, addr, sock) = complete_msg.into_parts();
-
-                // Frame our socket with the peer wire protocol with no extensions (nested null protocol), and a max payload of 24KB
-                let peer = tokio_codec::Decoder::framed(
-                    PeerProtocolCodec::with_max_payload(PeerWireProtocol::new(NullProtocol::new()), 24 * 1024),
-                    sock,
-                );
-
-                // Create our peer identifier used by our peer manager
-                let peer_info = PeerInfo::new(addr, pid, hash, Extensions::new());
-
-                // Map to a message that can be fed to our peer manager
-                IPeerManagerMessage::AddPeer(peer_info, peer)
-            })
-            .forward(map_peer_manager_send)
-            .map(|_| ()),
-    );
-
-    // Will hold a mapping of BlockMetadata -> Vec<PeerInfo> to track which peers to send a queued block to
-    let disk_request_map = Rc::new(RefCell::new(HashMap::new()));
-    let (select_send, select_recv) = mpsc::channel(50);
-
-    // Map out the errors for these sinks so they match
-    let map_select_send = select_send.clone().sink_map_err(|_| ());
-    let map_disk_manager_send = disk_manager_send.clone().sink_map_err(|()| ());
-
-    // Hook up a future that receives messages from the peer manager, and forwards request to the disk manager or selection manager (using loop fn
-    // here because we need to be able to access state, like request_map and a different future combinator wouldn't let us keep it around to access)
-    core.handle().spawn(future::loop_fn(
-        (
-            peer_manager_recv,
-            info_hash,
-            disk_request_map.clone(),
-            map_select_send,
-            map_disk_manager_send,
-        ),
-        |(peer_manager_recv, info_hash, disk_request_map, select_send, disk_manager_send)| {
-            peer_manager_recv
-                .into_future()
-                .map_err(|_| ())
-                .and_then(move |(opt_item, peer_manager_recv)| {
-                    let opt_message = match opt_item.unwrap() {
-                        OPeerManagerMessage::ReceivedMessage(info, message) => {
-                            match message {
-                                PeerWireProtocolMessage::Choke => Some(Either::A(SelectState::Choke(info))),
-                                PeerWireProtocolMessage::UnChoke => Some(Either::A(SelectState::UnChoke(info))),
-                                PeerWireProtocolMessage::Interested => Some(Either::A(SelectState::Interested(info))),
-                                PeerWireProtocolMessage::UnInterested => Some(Either::A(SelectState::UnInterested(info))),
-                                PeerWireProtocolMessage::Have(have) => Some(Either::A(SelectState::Have(info, have))),
-                                PeerWireProtocolMessage::BitField(bitfield) => {
-                                    Some(Either::A(SelectState::BitField(info, bitfield)))
-                                }
-                                PeerWireProtocolMessage::Request(request) => {
-                                    let block_metadata = BlockMetadata::new(
-                                        info_hash,
-                                        u64::from(request.piece_index()),
-                                        u64::from(request.block_offset()),
-                                        request.block_length(),
-                                    );
-                                    let mut request_map_mut = disk_request_map.borrow_mut();
-
-                                    // Add the block metadata to our request map, and add the peer as an entry there
-                                    let block_entry = request_map_mut.entry(block_metadata);
-                                    let peers_requested = block_entry.or_insert(Vec::new());
-
-                                    peers_requested.push(info);
-
-                                    Some(Either::B(IDiskMessage::LoadBlock(BlockMut::new(
-                                        block_metadata,
-                                        vec![0u8; block_metadata.block_length()].into(),
-                                    ))))
-                                }
-                                PeerWireProtocolMessage::Piece(piece) => {
-                                    let block_metadata = BlockMetadata::new(
-                                        info_hash,
-                                        u64::from(piece.piece_index()),
-                                        u64::from(piece.block_offset()),
-                                        piece.block_length(),
-                                    );
-
-                                    // Peer sent us a block, send it over to the disk manager to be processed
-                                    Some(Either::B(IDiskMessage::ProcessBlock(Block::new(
-                                        block_metadata,
-                                        piece.block(),
-                                    ))))
-                                }
-                                _ => None,
-                            }
-                        }
-                        OPeerManagerMessage::PeerAdded(info) => Some(Either::A(SelectState::NewPeer(info))),
-                        OPeerManagerMessage::SentMessage(_, _) => None,
-                        OPeerManagerMessage::PeerRemoved(info) => {
-                            println!("We Removed Peer {info:?} From The Peer Manager");
-                            Some(Either::A(SelectState::RemovedPeer(info)))
-                        }
-                        OPeerManagerMessage::PeerDisconnect(info) => {
-                            println!("Peer {info:?} Disconnected From Us");
-                            Some(Either::A(SelectState::RemovedPeer(info)))
-                        }
-                        OPeerManagerMessage::PeerError(info, error) => {
-                            println!("Peer {info:?} Disconnected With Error: {error:?}");
-                            Some(Either::A(SelectState::RemovedPeer(info)))
-                        }
-                    };
-
-                    // Could optimize out the box, but for the example, this is cleaner and shorter
-                    let result_future: Box<Future<Item = Loop<_, _>, Error = ()>> = match opt_message {
-                        Some(Either::A(select_message)) => Box::new(select_send.send(select_message).map(move |select_send| {
-                            Loop::Continue((peer_manager_recv, info_hash, disk_request_map, select_send, disk_manager_send))
-                        })),
-                        Some(Either::B(disk_message)) => {
-                            Box::new(disk_manager_send.send(disk_message).map(move |disk_manager_send| {
-                                Loop::Continue((peer_manager_recv, info_hash, disk_request_map, select_send, disk_manager_send))
-                            }))
-                        }
-                        None => Box::new(future::ok(Loop::Continue((
-                            peer_manager_recv,
-                            info_hash,
-                            disk_request_map,
-                            select_send,
-                            disk_manager_send,
-                        )))),
-                    };
-
-                    result_future
-                })
-        },
-    ));
-
-    // Map out the errors for these sinks so they match
-    let map_select_send = select_send.clone().sink_map_err(|_| ());
-    let map_peer_manager_send = peer_manager_send.clone().sink_map_err(|_| ());
-
-    // Hook up a future that receives from the disk manager, and forwards to the peer manager or select manager
-    core.handle().spawn(future::loop_fn(
-        (
-            disk_manager_recv,
-            disk_request_map.clone(),
-            map_select_send,
-            map_peer_manager_send,
-        ),
-        |(disk_manager_recv, disk_request_map, select_send, peer_manager_send)| {
-            disk_manager_recv
-                .into_future()
-                .map_err(|_| ())
-                .and_then(|(opt_item, disk_manager_recv)| {
-                    let opt_message = match opt_item.unwrap() {
-                        ODiskMessage::BlockLoaded(block) => {
-                            let (metadata, block) = block.into_parts();
-
-                            // Lookup the peer info given the block metadata
-                            let mut request_map_mut = disk_request_map.borrow_mut();
-                            let peer_list = request_map_mut.get_mut(&metadata).unwrap();
-                            let peer_info = peer_list.remove(1);
-
-                            // Pack up our block into a peer wire protocol message and send it off to the peer
-                            #[allow(clippy::cast_possible_truncation)]
-                            let piece =
-                                PieceMessage::new(metadata.piece_index() as u32, metadata.block_offset() as u32, block.freeze());
-                            let pwp_message = PeerWireProtocolMessage::Piece(piece);
-
-                            Some(Either::B(IPeerManagerMessage::SendMessage(peer_info, 0, pwp_message)))
-                        }
-                        ODiskMessage::TorrentAdded(_) => Some(Either::A(SelectState::TorrentAdded)),
-                        ODiskMessage::TorrentSynced(_) => Some(Either::A(SelectState::TorrentSynced)),
-                        ODiskMessage::FoundGoodPiece(_, index) => Some(Either::A(SelectState::GoodPiece(index))),
-                        ODiskMessage::FoundBadPiece(_, index) => Some(Either::A(SelectState::BadPiece(index))),
-                        ODiskMessage::BlockProcessed(_) => Some(Either::A(SelectState::BlockProcessed)),
-                        _ => None,
-                    };
-
-                    // Could optimize out the box, but for the example, this is cleaner and shorter
-                    let result_future: Box<Future<Item = Loop<_, _>, Error = ()>> = match opt_message {
-                        Some(Either::A(select_message)) => Box::new(select_send.send(select_message).map(|select_send| {
-                            Loop::Continue((disk_manager_recv, disk_request_map, select_send, peer_manager_send))
-                        })),
-                        Some(Either::B(peer_message)) => {
-                            Box::new(peer_manager_send.send(peer_message).map(|peer_manager_send| {
-                                Loop::Continue((disk_manager_recv, disk_request_map, select_send, peer_manager_send))
-                            }))
-                        }
-                        None => Box::new(future::ok(Loop::Continue((
-                            disk_manager_recv,
-                            disk_request_map,
-                            select_send,
-                            peer_manager_send,
-                        )))),
-                    };
-
-                    result_future
-                })
-        },
-    ));
-
-    // Generate data structure to track the requests we need to make, the requests that have been fulfilled, and an active peers list
-    let piece_requests = generate_requests(metainfo.info(), 16 * 1024);
-
-    // Have our disk manager allocate space for our torrent and start tracking it
-    core.run(disk_manager_send.send(IDiskMessage::AddTorrent(metainfo.clone())))
-        .unwrap();
-
-    // For any pieces we already have on the file system (and are marked as good), we will be removing them from our requests map
-    let (select_recv, piece_requests, cur_pieces) = core
-        .run(future::loop_fn(
-            (select_recv, piece_requests, 0),
-            |(select_recv, mut piece_requests, cur_pieces)| {
-                select_recv
-                    .into_future()
-                    .map(move |(opt_item, select_recv)| {
-                        match opt_item.unwrap() {
-                            // Disk manager identified a good piece already downloaded
-                            SelectState::GoodPiece(index) => {
-                                piece_requests.retain(|req| u64::from(req.piece_index()) != index);
-                                Loop::Continue((select_recv, piece_requests, cur_pieces + 1))
-                            }
-                            // Disk manager is finished identifying good pieces, torrent has been added
-                            SelectState::TorrentAdded => Loop::Break((select_recv, piece_requests, cur_pieces)),
-                            // Shouldn't be receiving any other messages...
-                            message => panic!("Unexpected Message Received In Selection Receiver: {message:?}"),
-                        }
-                    })
-                    .map_err(|_| ())
-            },
-        ))
-        .unwrap();
-
-    /*
-    // Setup the dht which will be the only peer discovery service we use in this example
-    let legacy_handshaker = LegacyHandshaker::new(handshaker_send);
-    let dht = DhtBuilder::with_router(Router::uTorrent)
-        .set_read_only(false)
-        .start_mainline(legacy_handshaker).unwrap();
-
-    dht.search(info_hash, true);
-    */
-
-    // Send the peer given from the command line over to the handshaker to initiate a connection
-    core.run(
-        handshaker_send
-            .send(InitiateMessage::new(Protocol::BitTorrent, info_hash, peer_addr))
-            .map_err(|_| ()),
-    )
-    .unwrap();
-
-    // Finally, setup our main event loop to drive the tasks we setup earlier
-    let map_peer_manager_send = peer_manager_send.sink_map_err(|_| ());
-    let total_pieces = metainfo.info().pieces().count();
-
-    println!(
-        "Current Pieces: {}\nTotal Pieces: {}\nRequests Left: {}",
-        cur_pieces,
-        total_pieces,
-        piece_requests.len()
-    );
-
-    let result: Result<(), ()> = core.run(future::loop_fn(
-        (
-            select_recv,
-            map_peer_manager_send,
-            piece_requests,
-            None,
-            false,
-            0,
-            cur_pieces,
-            total_pieces,
-        ),
-        |(
-            select_recv,
-            map_peer_manager_send,
-            mut piece_requests,
-            mut opt_peer,
-            mut unchoked,
-            mut blocks_pending,
-            mut cur_pieces,
-            total_pieces,
-        )| {
-            select_recv
-                .into_future()
-                .map_err(|_| ())
-                .and_then(move |(opt_message, select_recv)| {
-                    // Handle the current selection message, decide any control messages we need to send
-                    let send_messages = match opt_message.unwrap() {
-                        SelectState::BlockProcessed => {
-                            // Disk manager let us know a block was processed (one of our requests made it
-                            // from the peer manager, to the disk manager, and this is the acknowledgement)
-                            blocks_pending -= 1;
-                            vec![]
-                        }
-                        SelectState::Choke(_) => {
-                            // Peer choked us, cant be sending any requests to them for now
-                            unchoked = false;
-                            vec![]
-                        }
-                        SelectState::UnChoke(_) => {
-                            // Peer unchoked us, we can continue sending sending requests to them
-                            unchoked = true;
-                            vec![]
-                        }
-                        SelectState::NewPeer(info) => {
-                            // A new peer connected to us, store its contact info (just supported one peer atm),
-                            // and go ahead and express our interest in them, and unchoke them (we can upload to them)
-                            // We don't send a bitfield message (just to keep things simple).
-                            opt_peer = Some(info);
-                            vec![
-                                IPeerManagerMessage::SendMessage(info, 0, PeerWireProtocolMessage::Interested),
-                                IPeerManagerMessage::SendMessage(info, 0, PeerWireProtocolMessage::UnChoke),
-                            ]
-                        }
-                        SelectState::GoodPiece(piece) => {
-                            // Disk manager has processed enough blocks to make up a piece, and that piece
-                            // was verified to be good (checksummed). Go ahead and increment the number of
-                            // pieces we have. We don't handle bad pieces here (since we deleted our request
-                            // but ideally, we would recreate those requests and resend/blacklist the peer).
-                            cur_pieces += 1;
-
-                            if let Some(peer) = opt_peer {
-                                // Send our have message back to the peer
-                                vec![IPeerManagerMessage::SendMessage(
-                                    peer,
-                                    0,
-                                    PeerWireProtocolMessage::Have(HaveMessage::new(piece.try_into().unwrap())),
-                                )]
-                            } else {
-                                vec![]
-                            }
-                        }
-                        // Decided not to handle these two cases here
-                        SelectState::RemovedPeer(info) => panic!("Peer {info:?} Got Disconnected"),
-                        SelectState::BadPiece(_) => panic!("Peer Gave Us Bad Piece"),
-                        _ => vec![],
-                    };
-
-                    // Need a type annotation of this return type, provide that
-                    let result: Box<Future<Item = Loop<_, _>, Error = ()>> = if cur_pieces == total_pieces {
-                        // We have all of the (unique) pieces required for our torrent
-                        Box::new(future::ok(Loop::Break(())))
-                    } else if let Some(peer) = opt_peer {
-                        // We have peer contact info, if we are unchoked, see if we can queue up more requests
-                        let next_piece_requests = if unchoked {
-                            let take_blocks = cmp::min(MAX_PENDING_BLOCKS - blocks_pending, piece_requests.len());
-                            blocks_pending += take_blocks;
-
-                            piece_requests
-                                .drain(0..take_blocks)
-                                .map(move |item| {
-                                    Ok::<_, ()>(IPeerManagerMessage::SendMessage(
-                                        peer,
-                                        0,
-                                        PeerWireProtocolMessage::Request(item),
-                                    ))
-                                })
-                                .collect()
-                        } else {
-                            vec![]
-                        };
-
-                        // First, send any control messages, then, send any more piece requests
-                        Box::new(
-                            map_peer_manager_send
-                                .send_all(stream::iter_result(send_messages.into_iter().map(Ok::<_, ()>)))
-                                .map_err(|()| ())
-                                .and_then(|(map_peer_manager_send, _)| {
-                                    map_peer_manager_send.send_all(stream::iter_result(next_piece_requests))
-                                })
-                                .map_err(|()| ())
-                                .map(move |(map_peer_manager_send, _)| {
-                                    Loop::Continue((
-                                        select_recv,
-                                        map_peer_manager_send,
-                                        piece_requests,
-                                        opt_peer,
-                                        unchoked,
-                                        blocks_pending,
-                                        cur_pieces,
-                                        total_pieces,
-                                    ))
-                                }),
-                        )
-                    } else {
-                        // Not done yet, and we don't have any peer info stored (haven't received the peer yet)
-                        Box::new(future::ok(Loop::Continue((
-                            select_recv,
-                            map_peer_manager_send,
-                            piece_requests,
-                            opt_peer,
-                            unchoked,
-                            blocks_pending,
-                            cur_pieces,
-                            total_pieces,
-                        ))))
-                    };
-
-                    result
-                })
-        },
-    ));
-
-    result.unwrap();
-}
-
-/// Generate a mapping of piece index to list of block requests for that piece, given a block size.
-///
-/// Note, most clients will drop connections for peers requesting block sizes above 16KB.
-fn generate_requests(info: &Info, block_size: usize) -> Vec<RequestMessage> {
-    let mut requests = Vec::new();
-
-    // Grab our piece length, and the sum of the lengths of each file in the torrent
-    let piece_len: u64 = info.piece_length();
-    let mut total_file_length: u64 = info.files().map(metainfo::File::length).sum();
-
-    // Loop over each piece (keep subtracting total file length by piece size, use cmp::min to handle last, smaller piece)
-    let mut piece_index: u64 = 0;
-    while total_file_length != 0 {
-        let next_piece_len = cmp::min(total_file_length, piece_len);
-
-        // For all whole blocks, push the block index and block_size
-        let whole_blocks = next_piece_len / block_size as u64;
-        for block_index in 0..whole_blocks {
-            let block_offset = block_index * block_size as u64;
-
-            #[allow(clippy::cast_possible_truncation)]
-            requests.push(RequestMessage::new(piece_index as u32, block_offset as u32, block_size));
-        }
-
-        // Check for any last smaller block within the current piece
-        let partial_block_length = next_piece_len % block_size as u64;
-        if partial_block_length != 0 {
-            let block_offset = whole_blocks * block_size as u64;
-
-            requests.push(RequestMessage::new(
-                piece_index.try_into().unwrap(),
-                block_offset.try_into().unwrap(),
-                partial_block_length.try_into().unwrap(),
-            ));
-        }
-
-        // Take this piece out of the total length, increment to the next piece
-        total_file_length -= next_piece_len;
-        piece_index += 1;
-    }
-
-    requests
-}
diff --git a/packages/dht/examples/debug.rs b/packages/dht/examples/debug.rs
index df7e3d5ad..5651fbdcc 100644
--- a/packages/dht/examples/debug.rs
+++ b/packages/dht/examples/debug.rs
@@ -7,6 +7,7 @@ use dht::handshaker_trait::HandshakerTrait;
 use dht::{DhtBuilder, Router};
 use futures::future::BoxFuture;
 use futures::StreamExt;
+use tokio::task::JoinSet;
 use tracing::level_filters::LevelFilter;
 use util::bt::{InfoHash, PeerId};

@@ -65,6 +66,8 @@ async fn main() {
         tracing_stderr_init(LevelFilter::INFO);
     });

+    let mut tasks = JoinSet::new();
+
     let hash = InfoHash::from_bytes(b"My Unique Info Hash");

     let handshaker = SimpleHandshaker {
@@ -80,7 +83,7 @@ async fn main() {

     // Spawn a thread to listen to and report events
     let mut events = dht.events().await;
-    tokio::spawn(async move {
+    tasks.spawn(async move {
         while let Some(event) = events.next().await {
             println!("\nReceived Dht Event {event:?}");
         }
diff --git a/packages/dht/src/builder.rs b/packages/dht/src/builder.rs
index 6e043fd6a..93c640511 100644
--- a/packages/dht/src/builder.rs
+++ b/packages/dht/src/builder.rs
@@ -5,6 +5,7 @@ use std::sync::Arc;
 use futures::channel::mpsc::{self, Receiver, Sender};
 use futures::SinkExt as _;
 use tokio::net::UdpSocket;
+use tokio::task::JoinSet;
 use util::bt::InfoHash;
 use util::net;

@@ -15,6 +16,7 @@ use crate::worker::{self, DhtEvent, OneshotTask, ShutdownCause};

 /// Maintains a Distributed Hash (Routing) Table.
 pub struct MainlineDht {
     main_task_sender: Sender<OneshotTask>,
+    _tasks: JoinSet<()>,
 }

 impl MainlineDht {
@@ -29,7 +31,7 @@ impl MainlineDht {
         let kill_sock = send_sock.clone();
         let kill_addr = send_sock.local_addr()?;

-        let main_task_sender = worker::start_mainline_dht(
+        let (main_task_sender, tasks) = worker::start_mainline_dht(
             &send_sock,
             recv_sock,
             builder.read_only,
@@ -51,7 +53,10 @@ impl MainlineDht {
             tracing::warn!("bip_dt: MainlineDht failed to send a start bootstrap message...");
         }

-        Ok(MainlineDht { main_task_sender })
+        Ok(MainlineDht {
+            main_task_sender,
+            _tasks: tasks,
+        })
     }

     /// Perform a search for the given `InfoHash` with an optional announce on the closest nodes.
diff --git a/packages/dht/src/worker/bootstrap.rs b/packages/dht/src/worker/bootstrap.rs
index 0493f33e6..7bcb41f77 100644
--- a/packages/dht/src/worker/bootstrap.rs
+++ b/packages/dht/src/worker/bootstrap.rs
@@ -3,9 +3,10 @@ use std::net::SocketAddr;
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::{Arc, Mutex, RwLock};

-use futures::channel::mpsc::Sender;
+use futures::channel::mpsc::{SendError, Sender};
 use futures::future::BoxFuture;
 use futures::{FutureExt as _, SinkExt as _};
+use tokio::task::JoinSet;
 use tokio::time::{sleep, Duration};
 use util::bt::{self, NodeId};

@@ -43,6 +44,7 @@ pub struct TableBootstrap {
     active_messages: Mutex<HashSet<TransactionID>>,
     starting_routers: HashSet<SocketAddr>,
     curr_bootstrap_bucket: AtomicUsize,
+    tasks: Arc<Mutex<JoinSet<Result<(), SendError>>>>,
 }

 impl TableBootstrap {
@@ -59,6 +61,7 @@ impl TableBootstrap {
             starting_routers: router_filter,
             active_messages: Mutex::default(),
             curr_bootstrap_bucket: AtomicUsize::default(),
+            tasks: Arc::default(),
         }
     }

@@ -66,7 +69,7 @@ impl TableBootstrap {
         &self,
         mut out: Sender<(Vec<u8>, SocketAddr)>,
         mut scheduled_task_sender: Sender<ScheduledTaskCheck>,
-    ) -> BootstrapStatus {
+    ) -> Result<BootstrapStatus, BootstrapStatus> {
         // Reset the bootstrap state
         self.active_messages.lock().unwrap().clear();
         self.curr_bootstrap_bucket.store(0, Ordering::Relaxed);
@@ -75,14 +78,21 @@ impl TableBootstrap {
         let trans_id = self.id_generator.lock().unwrap().generate();

         // Set a timer to begin the actual bootstrap
-        tokio::spawn(async move {
+        let abort = self.tasks.lock().unwrap().spawn(async move {
             sleep(Duration::from_millis(BOOTSTRAP_INITIAL_TIMEOUT)).await;
-            if scheduled_task_sender
+
+            match scheduled_task_sender
                 .send(ScheduledTaskCheck::BootstrapTimeout(trans_id))
                 .await
-                .is_err()
             {
-                tracing::error!("bip_dht: Failed to send scheduled task check for bootstrap timeout");
+                Ok(()) => {
+                    tracing::debug!("sent scheduled bootstrap timeout");
+                    Ok(())
+                }
+                Err(e) => {
+                    tracing::debug!("error sending scheduled bootstrap timeout: {e}");
+                    Err(e)
+                }
             }
         });

@@ -97,11 +107,12 @@ impl TableBootstrap {
         for addr in self.starting_routers.iter().chain(self.starting_nodes.iter()) {
             if out.send((find_node_msg.clone(), *addr)).await.is_err() {
                 tracing::error!("bip_dht: Failed to send bootstrap message to router through channel...");
-                return BootstrapStatus::Failed;
+                abort.abort();
+                return Err(BootstrapStatus::Failed);
             }
         }

-        self.current_bootstrap_status()
+        Ok(self.current_bootstrap_status())
     }

     pub fn is_router(&self, addr: &SocketAddr) -> bool {
@@ -302,15 +313,22 @@ impl TableBootstrap {
             messages_sent += 1;

             // Schedule a timeout check
-            let mut task_sender_clone = scheduled_task_sender.clone();
-            tokio::spawn(async move {
-                sleep(Duration::from_millis(BOOTSTRAP_NODE_TIMEOUT)).await;
-                if task_sender_clone
+            let mut this_scheduled_task_sender = scheduled_task_sender.clone();
+            self.tasks.lock().unwrap().spawn(async move {
+                sleep(Duration::from_millis(BOOTSTRAP_NODE_TIMEOUT)).await;
+
+                match this_scheduled_task_sender
                     .send(ScheduledTaskCheck::BootstrapTimeout(trans_id))
                     .await
-                    .is_err()
                 {
-                    tracing::error!("bip_dht: Failed to send scheduled task check for bootstrap timeout");
+                    Ok(()) => {
+                        tracing::debug!("sent scheduled bootstrap timeout");
+                        Ok(())
+                    }
+                    Err(e) => {
+                        tracing::debug!("error sending scheduled bootstrap timeout: {e}");
+                        Err(e)
+                    }
                 }
             });
         }
diff --git a/packages/dht/src/worker/handler.rs b/packages/dht/src/worker/handler.rs
index 3dda90f6f..8d5bd8322 100644
--- a/packages/dht/src/worker/handler.rs
+++ b/packages/dht/src/worker/handler.rs
@@ -9,6 +9,7 @@ use futures::channel::mpsc::{self, Sender};
 use futures::future::BoxFuture;
 use futures::{FutureExt, SinkExt, StreamExt as _};
 use tokio::net::UdpSocket;
+use tokio::task::JoinSet;
 use util::bt::InfoHash;
 use util::convert;
 use util::net::IpAddr;

@@ -51,7 +52,7 @@ pub fn create_dht_handler(
     handshaker: H,
     kill_sock: Arc<UdpSocket>,
     kill_addr: SocketAddr,
-) -> Sender<OneshotTask>
+) -> (Sender<OneshotTask>, JoinSet<()>)
 where
     H: HandshakerTrait + 'static,
 {
@@ -72,7 +73,9 @@
         handshaker,
     );

-    tokio::spawn(async move {
+    let mut tasks = JoinSet::new();
+
+    tasks.spawn(async move {
         while let Some(task) = tasks_receiver.next().await {
             match task {
                 Task::Main(main_task) => handler.handle_task(main_task).await,
@@ -90,7 +93,7 @@
         tracing::info!("bip_dht: DhtHandler gracefully shut down, exiting thread...");
     });

-    main_task_sender
+    (main_task_sender, tasks)
 }

 // ----------------------------------------------------------------------------//
@@ -655,36 +658,13 @@
            );

            match bootstrap_status {
-                BootstrapStatus::Idle => true,
-                BootstrapStatus::Bootstrapping => false,
-                BootstrapStatus::Failed => {
+                Ok(BootstrapStatus::Idle) => true,
+                Ok(BootstrapStatus::Bootstrapping) => false,
+                Err(BootstrapStatus::Failed) => {
                    self.handle_shutdown(ShutdownCause::Unspecified);
                    false
                }
-                BootstrapStatus::Completed => {
-                    // Check if our bootstrap was actually good
-                    if should_rebootstrap(&self.routing_table.read().unwrap()) {
-                        let Some(TableAction::Bootstrap(bootstrap, attempts)) =
-                            self.table_actions.lock().unwrap().get_mut(&action_id).cloned()
-                        else {
-                            panic!("bip_dht: Bug, in DhtHandler...")
-                        };
-
-                        attempt_rebootstrap(
-                            bootstrap,
-                            attempts,
-                            self.routing_table.clone(),
-                            self.out_channel.clone(),
-                            self.main_task_sender.clone(),
-                            self.scheduled_task_sender.clone(),
-                        )
-                        .await
-                            == Some(false)
-                    } else {
-                        true
-                    }
-                }
+                Ok(_) | Err(_) => unreachable!(),
            }
        };
@@ -1048,27 +1028,13 @@ fn attempt_rebootstrap(
        let bootstrap_status = bootstrap.start_bootstrap(out.clone(), scheduled_task_sender.clone()).await;

        match bootstrap_status {
-            BootstrapStatus::Idle => Some(false),
-            BootstrapStatus::Bootstrapping => Some(true),
-            BootstrapStatus::Failed => {
+            Ok(BootstrapStatus::Idle) => Some(false),
+            Ok(BootstrapStatus::Bootstrapping) => Some(true),
+            Err(BootstrapStatus::Failed) => {
                shutdown_event_loop(main_task_sender, ShutdownCause::Unspecified).await;
                None
            }
-            BootstrapStatus::Completed => {
-                if should_rebootstrap(&routing_table.read().unwrap()) {
-                    attempt_rebootstrap(
-                        bootstrap,
-                        attempts,
-                        routing_table.clone(),
-                        out,
-                        main_task_sender,
-                        scheduled_task_sender,
-                    )
-                    .await
-                } else {
-                    Some(false)
-                }
-            }
+            Ok(_) | Err(_) => unreachable!(),
        }
    }
 }
diff --git a/packages/dht/src/worker/lookup.rs b/packages/dht/src/worker/lookup.rs
index 163737660..60f46498a 100644
--- a/packages/dht/src/worker/lookup.rs
+++ b/packages/dht/src/worker/lookup.rs
@@ -5,9 +5,10 @@
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::{Arc, Mutex, RwLock};

 use bencode::BRefAccess;
-use futures::channel::mpsc::Sender;
+use futures::channel::mpsc::{SendError, Sender};
 use futures::future::BoxFuture;
 use futures::{FutureExt, SinkExt as _};
+use tokio::task::JoinSet;
 use tokio::time::{sleep, Duration, Instant};
 use util::bt::{self, InfoHash, NodeId};
 use util::net;

@@ -52,6 +53,7 @@ pub struct TableLookup {
     announce_tokens: Mutex<HashMap<Node, Vec<u8>>>,
     requested_nodes: Mutex<HashSet<Node>>,
     all_sorted_nodes: Mutex<Vec<(NodeId, Node, bool)>>,
+    tasks: Arc<Mutex<JoinSet<Result<(), SendError>>>>,
 }

 impl TableLookup {
@@ -95,6 +97,7 @@ impl TableLookup {
             announce_tokens: Mutex::new(HashMap::new()),
             requested_nodes: Mutex::new(HashSet::new()),
             active_lookups: Mutex::new(HashMap::with_capacity(INITIAL_PICK_NUM)),
+            tasks: Arc::default(),
         };

         if table_lookup
@@ -364,15 +367,22 @@ impl TableLookup {
             messages_sent += 1;

             // Schedule a timeout check
-            let mut task_sender_clone = scheduled_task_sender.clone();
-            tokio::spawn(async move {
+            let mut this_scheduled_task_sender = scheduled_task_sender.clone();
+            self.tasks.lock().unwrap().spawn(async move {
                 sleep(Duration::from_millis(LOOKUP_TIMEOUT_MS)).await;
-                if task_sender_clone
+
+                match this_scheduled_task_sender
                     .send(ScheduledTaskCheck::LookupTimeout(trans_id))
                     .await
-                    .is_err()
                 {
-                    tracing::error!("bip_dht: Failed to send scheduled task check for lookup timeout");
+                    Ok(()) => {
+                        tracing::debug!("sent scheduled lookup timeout");
+                        Ok(())
+                    }
+                    Err(e) => {
+                        tracing::debug!("error sending scheduled lookup timeout: {e}");
+                        Err(e)
+                    }
                 }
             });
         }
diff --git a/packages/dht/src/worker/mod.rs b/packages/dht/src/worker/mod.rs
index c686eb034..9a78924ad 100644
--- a/packages/dht/src/worker/mod.rs
+++ b/packages/dht/src/worker/mod.rs
@@ -3,6 +3,7 @@ use std::sync::Arc;

 use futures::channel::mpsc::Sender;
 use tokio::net::UdpSocket;
+use tokio::task::JoinSet;
 use util::bt::InfoHash;

 use crate::handshaker_trait::HandshakerTrait;
@@ -77,7 +78,7 @@ pub fn start_mainline_dht(
     handshaker: H,
     kill_sock: Arc<UdpSocket>,
     kill_addr: SocketAddr,
-) -> Sender<OneshotTask>
+) -> (Sender<OneshotTask>, JoinSet<()>)
 where
     H: HandshakerTrait + 'static,
 {
@@ -87,7 +88,7 @@
     let routing_table = RoutingTable::new(table::random_node_id());
     let message_sender = handler::create_dht_handler(routing_table, outgoing, read_only, handshaker, kill_sock, kill_addr);

-    messenger::create_incoming_messenger(recv_socket, message_sender.clone());
+    messenger::create_incoming_messenger(recv_socket, message_sender.0.clone());

     message_sender
 }
diff --git a/packages/dht/src/worker/refresh.rs b/packages/dht/src/worker/refresh.rs
index 33ed6c71d..69bce12fc 100644
--- a/packages/dht/src/worker/refresh.rs
+++ b/packages/dht/src/worker/refresh.rs
@@ -2,8 +2,9 @@ use std::net::SocketAddr;
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::{Arc, Mutex, RwLock};

-use futures::channel::mpsc::Sender;
+use futures::channel::mpsc::{SendError, Sender};
 use futures::SinkExt as _;
+use tokio::task::JoinSet;
 use tokio::time::{sleep, Duration};
 use util::bt::{self, NodeId};

@@ -27,6 +28,7 @@ pub enum RefreshStatus {
 pub struct TableRefresh {
     id_generator: Mutex<MIDGenerator>,
     curr_refresh_bucket: AtomicUsize,
+    tasks: Arc<Mutex<JoinSet<Result<(), SendError>>>>,
 }

 impl TableRefresh {
@@ -34,6 +36,7 @@ impl TableRefresh {
         TableRefresh {
             id_generator: Mutex::new(id_generator),
             curr_refresh_bucket: AtomicUsize::default(),
+            tasks: Arc::default(),
         }
     }

@@ -91,14 +94,18 @@ impl TableRefresh {
         let trans_id = self.id_generator.lock().unwrap().generate();

         // Start a timer for the next refresh
-        tokio::spawn(async move {
+        self.tasks.lock().unwrap().spawn(async move {
             sleep(Duration::from_millis(REFRESH_INTERVAL_TIMEOUT)).await;
-            if scheduled_task_sender
-                .send(ScheduledTaskCheck::TableRefresh(trans_id))
-                .await
-                .is_err()
-            {
-                tracing::error!("bip_dht: Failed to send scheduled task check for table refresh");
+
+            match scheduled_task_sender.send(ScheduledTaskCheck::TableRefresh(trans_id)).await {
+                Ok(()) => {
+                    tracing::debug!("sent scheduled refresh timeout");
+                    Ok(())
+                }
+                Err(e) => {
+                    tracing::debug!("error sending scheduled refresh timeout: {e}");
+                    Err(e)
+                }
            }
        });
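
Note on the pattern this patch applies throughout: detached tokio::spawn calls are replaced by tasks spawned into a JoinSet (sometimes shared behind Arc<Mutex<...>>) so the owner can abort or drain them, and long-running setup is raced against Ctrl-C with tokio::select!. The sketch below shows that pattern in isolation. It is a minimal, standalone illustration rather than code from this patch: the task bodies, durations, and the &'static str error type are placeholders standing in for the scheduled timeout checks and their SendError, and it only assumes tokio with its "full" feature set.

use std::time::Duration;

use tokio::task::JoinSet;

// Wait for Ctrl-C, mirroring the ctrl_c() helper in the examples.
async fn ctrl_c() {
    tokio::signal::ctrl_c().await.expect("failed to listen for event");
}

#[tokio::main]
async fn main() {
    // Owning the JoinSet (rather than detaching tasks with tokio::spawn)
    // is what lets a holder cancel every pending timer at once.
    let mut tasks: JoinSet<Result<(), &'static str>> = JoinSet::new();

    for id in 0..3_u64 {
        tasks.spawn(async move {
            // Stand-in for the scheduled timeout checks in bootstrap/lookup/refresh.
            tokio::time::sleep(Duration::from_millis(100 * (id + 1))).await;
            println!("timeout {id} fired");
            Ok(())
        });
    }

    // Race task completion against Ctrl-C, like the SendUber/MainDht select arms.
    loop {
        tokio::select! {
            joined = tasks.join_next() => match joined {
                Some(result) => println!("task finished: {result:?}"),
                None => break, // the set is empty: all timers are done
            },
            () = ctrl_c() => {
                println!("interrupted, aborting remaining timers...");
                tasks.shutdown().await;
                break;
            }
        }
    }
}

Because a JoinSet aborts everything it still owns when dropped, storing it in a struct field (as MainlineDht does with _tasks) ties the timers' lifetimes to the DHT handle itself, which is the leak this patch closes.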