Skip to content

Commit

Permalink
fix(mdns): move IO off main task
Browse files Browse the repository at this point in the history
Resolves: #2591.

Pull-Request: #4623.
  • Loading branch information
thomaseizinger authored Oct 20, 2023
1 parent d26e04a commit 0181e86
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 93 deletions.
2 changes: 2 additions & 0 deletions protocols/mdns/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
## 0.45.0 - unreleased

- Don't perform IO in `Behaviour::poll`.
See [PR 4623](https://github.com/libp2p/rust-libp2p/pull/4623).

## 0.44.0

Expand Down
3 changes: 2 additions & 1 deletion protocols/mdns/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ keywords = ["peer-to-peer", "libp2p", "networking"]
categories = ["network-programming", "asynchronous"]

[dependencies]
async-std = { version = "1.12.0", optional = true }
async-io = { version = "1.13.0", optional = true }
data-encoding = "2.4.0"
futures = "0.3.28"
Expand All @@ -28,7 +29,7 @@ void = "1.0.2"

[features]
tokio = ["dep:tokio", "if-watch/tokio"]
async-io = ["dep:async-io", "if-watch/smol"]
async-io = ["dep:async-io", "dep:async-std", "if-watch/smol"]

[dev-dependencies]
async-std = { version = "1.9.0", features = ["attributes"] }
Expand Down
136 changes: 88 additions & 48 deletions protocols/mdns/src/behaviour.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ mod timer;
use self::iface::InterfaceState;
use crate::behaviour::{socket::AsyncSocket, timer::Builder};
use crate::Config;
use futures::Stream;
use futures::channel::mpsc;
use futures::{Stream, StreamExt};
use if_watch::IfEvent;
use libp2p_core::{Endpoint, Multiaddr};
use libp2p_identity::PeerId;
Expand All @@ -36,6 +37,8 @@ use libp2p_swarm::{
};
use smallvec::SmallVec;
use std::collections::hash_map::{Entry, HashMap};
use std::future::Future;
use std::sync::{Arc, RwLock};
use std::{cmp, fmt, io, net::IpAddr, pin::Pin, task::Context, task::Poll, time::Instant};

/// An abstraction to allow for compatibility with various async runtimes.
Expand All @@ -47,16 +50,27 @@ pub trait Provider: 'static {
/// The IfWatcher type.
type Watcher: Stream<Item = std::io::Result<IfEvent>> + fmt::Debug + Unpin;

type TaskHandle: Abort;

/// Create a new instance of the `IfWatcher` type.
fn new_watcher() -> Result<Self::Watcher, std::io::Error>;

fn spawn(task: impl Future<Output = ()> + Send + 'static) -> Self::TaskHandle;
}

#[allow(unreachable_pub)] // Not re-exported.
pub trait Abort {
fn abort(self);
}

/// The type of a [`Behaviour`] using the `async-io` implementation.
#[cfg(feature = "async-io")]
pub mod async_io {
use super::Provider;
use crate::behaviour::{socket::asio::AsyncUdpSocket, timer::asio::AsyncTimer};
use crate::behaviour::{socket::asio::AsyncUdpSocket, timer::asio::AsyncTimer, Abort};
use async_std::task::JoinHandle;
use if_watch::smol::IfWatcher;
use std::future::Future;

#[doc(hidden)]
pub enum AsyncIo {}
Expand All @@ -65,10 +79,21 @@ pub mod async_io {
type Socket = AsyncUdpSocket;
type Timer = AsyncTimer;
type Watcher = IfWatcher;
type TaskHandle = JoinHandle<()>;

fn new_watcher() -> Result<Self::Watcher, std::io::Error> {
IfWatcher::new()
}

fn spawn(task: impl Future<Output = ()> + Send + 'static) -> JoinHandle<()> {
async_std::task::spawn(task)
}
}

impl Abort for JoinHandle<()> {
fn abort(self) {
async_std::task::spawn(self.cancel());
}
}

pub type Behaviour = super::Behaviour<AsyncIo>;
Expand All @@ -78,8 +103,10 @@ pub mod async_io {
#[cfg(feature = "tokio")]
pub mod tokio {
use super::Provider;
use crate::behaviour::{socket::tokio::TokioUdpSocket, timer::tokio::TokioTimer};
use crate::behaviour::{socket::tokio::TokioUdpSocket, timer::tokio::TokioTimer, Abort};
use if_watch::tokio::IfWatcher;
use std::future::Future;
use tokio::task::JoinHandle;

#[doc(hidden)]
pub enum Tokio {}
Expand All @@ -88,10 +115,21 @@ pub mod tokio {
type Socket = TokioUdpSocket;
type Timer = TokioTimer;
type Watcher = IfWatcher;
type TaskHandle = JoinHandle<()>;

fn new_watcher() -> Result<Self::Watcher, std::io::Error> {
IfWatcher::new()
}

fn spawn(task: impl Future<Output = ()> + Send + 'static) -> Self::TaskHandle {
tokio::spawn(task)
}
}

impl Abort for JoinHandle<()> {
fn abort(self) {
JoinHandle::abort(&self)
}
}

pub type Behaviour = super::Behaviour<Tokio>;
Expand All @@ -110,8 +148,11 @@ where
/// Iface watcher.
if_watch: P::Watcher,

/// Mdns interface states.
iface_states: HashMap<IpAddr, InterfaceState<P::Socket, P::Timer>>,
/// Handles to tasks running the mDNS queries.
if_tasks: HashMap<IpAddr, P::TaskHandle>,

query_response_receiver: mpsc::Receiver<(PeerId, Multiaddr, Instant)>,
query_response_sender: mpsc::Sender<(PeerId, Multiaddr, Instant)>,

/// List of nodes that we have discovered, the address, and when their TTL expires.
///
Expand All @@ -124,7 +165,11 @@ where
/// `None` if `discovered_nodes` is empty.
closest_expiration: Option<P::Timer>,

listen_addresses: ListenAddresses,
/// The current set of listen addresses.
///
/// This is shared across all interface tasks using an [`RwLock`].
/// The [`Behaviour`] updates this upon new [`FromSwarm`] events where as [`InterfaceState`]s read from it to answer inbound mDNS queries.
listen_addresses: Arc<RwLock<ListenAddresses>>,

local_peer_id: PeerId,
}
Expand All @@ -135,10 +180,14 @@ where
{
/// Builds a new `Mdns` behaviour.
pub fn new(config: Config, local_peer_id: PeerId) -> io::Result<Self> {
let (tx, rx) = mpsc::channel(10); // Chosen arbitrarily.

Ok(Self {
config,
if_watch: P::new_watcher()?,
iface_states: Default::default(),
if_tasks: Default::default(),
query_response_receiver: rx,
query_response_sender: tx,
discovered_nodes: Default::default(),
closest_expiration: Default::default(),
listen_addresses: Default::default(),
Expand All @@ -147,6 +196,7 @@ where
}

/// Returns true if the given `PeerId` is in the list of nodes discovered through mDNS.
#[deprecated(note = "Use `discovered_nodes` iterator instead.")]
pub fn has_node(&self, peer_id: &PeerId) -> bool {
self.discovered_nodes().any(|p| p == peer_id)
}
Expand All @@ -157,6 +207,7 @@ where
}

/// Expires a node before the ttl.
#[deprecated(note = "Unused API. Will be removed in the next release.")]
pub fn expire_node(&mut self, peer_id: &PeerId) {
let now = Instant::now();
for (peer, _addr, expires) in &mut self.discovered_nodes {
Expand Down Expand Up @@ -225,28 +276,10 @@ where
}

fn on_swarm_event(&mut self, event: FromSwarm<Self::ConnectionHandler>) {
self.listen_addresses.on_swarm_event(&event);

match event {
FromSwarm::NewListener(_) => {
log::trace!("waking interface state because listening address changed");
for iface in self.iface_states.values_mut() {
iface.fire_timer();
}
}
FromSwarm::ConnectionClosed(_)
| FromSwarm::ConnectionEstablished(_)
| FromSwarm::DialFailure(_)
| FromSwarm::AddressChange(_)
| FromSwarm::ListenFailure(_)
| FromSwarm::NewListenAddr(_)
| FromSwarm::ExpiredListenAddr(_)
| FromSwarm::ListenerError(_)
| FromSwarm::ListenerClosed(_)
| FromSwarm::NewExternalAddrCandidate(_)
| FromSwarm::ExternalAddrExpired(_)
| FromSwarm::ExternalAddrConfirmed(_) => {}
}
self.listen_addresses
.write()
.unwrap_or_else(|e| e.into_inner())
.on_swarm_event(&event);
}

fn poll(
Expand All @@ -267,43 +300,50 @@ where
{
continue;
}
if let Entry::Vacant(e) = self.iface_states.entry(addr) {
match InterfaceState::new(addr, self.config.clone(), self.local_peer_id) {
if let Entry::Vacant(e) = self.if_tasks.entry(addr) {
match InterfaceState::<P::Socket, P::Timer>::new(
addr,
self.config.clone(),
self.local_peer_id,
self.listen_addresses.clone(),
self.query_response_sender.clone(),
) {
Ok(iface_state) => {
e.insert(iface_state);
e.insert(P::spawn(iface_state));
}
Err(err) => log::error!("failed to create `InterfaceState`: {}", err),
}
}
}
Ok(IfEvent::Down(inet)) => {
if self.iface_states.contains_key(&inet.addr()) {
if let Some(handle) = self.if_tasks.remove(&inet.addr()) {
log::info!("dropping instance {}", inet.addr());
self.iface_states.remove(&inet.addr());

handle.abort();
}
}
Err(err) => log::error!("if watch returned an error: {}", err),
}
}
// Emit discovered event.
let mut discovered = Vec::new();
for iface_state in self.iface_states.values_mut() {
while let Poll::Ready((peer, addr, expiration)) =
iface_state.poll(cx, &self.listen_addresses)

while let Poll::Ready(Some((peer, addr, expiration))) =
self.query_response_receiver.poll_next_unpin(cx)
{
if let Some((_, _, cur_expires)) = self
.discovered_nodes
.iter_mut()
.find(|(p, a, _)| *p == peer && *a == addr)
{
if let Some((_, _, cur_expires)) = self
.discovered_nodes
.iter_mut()
.find(|(p, a, _)| *p == peer && *a == addr)
{
*cur_expires = cmp::max(*cur_expires, expiration);
} else {
log::info!("discovered: {} {}", peer, addr);
self.discovered_nodes.push((peer, addr.clone(), expiration));
discovered.push((peer, addr));
}
*cur_expires = cmp::max(*cur_expires, expiration);
} else {
log::info!("discovered: {} {}", peer, addr);
self.discovered_nodes.push((peer, addr.clone(), expiration));
discovered.push((peer, addr));
}
}

if !discovered.is_empty() {
let event = Event::Discovered(discovered);
return Poll::Ready(ToSwarm::GenerateEvent(event));
Expand Down
Loading

0 comments on commit 0181e86

Please sign in to comment.