diff --git a/src/dns.rs b/src/dns.rs index 6a7a2ea3..e8728e08 100644 --- a/src/dns.rs +++ b/src/dns.rs @@ -1,7 +1,7 @@ use crate::tcp; use anyhow::{anyhow, Context}; use futures_util::{FutureExt, TryFutureExt}; -use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts}; +use hickory_resolver::config::{LookupIpStrategy, NameServerConfig, Protocol, ResolverConfig, ResolverOpts}; use hickory_resolver::name_server::{GenericConnector, RuntimeProvider, TokioRuntimeProvider}; use hickory_resolver::proto::iocompat::AsyncIoTokioAsStd; use hickory_resolver::proto::TokioTime; @@ -15,6 +15,22 @@ use std::time::Duration; use tokio::net::{TcpStream, UdpSocket}; use url::{Host, Url}; +// Interweave v4 and v6 addresses as per RFC8305. +// The first address is v6 if we have any v6 addresses. +pub fn sort_socket_addrs(socket_addrs: &[SocketAddr]) -> impl Iterator { + let mut pick_v6 = false; + let mut v6 = socket_addrs.iter().filter(|s| matches!(s, SocketAddr::V6(_))); + let mut v4 = socket_addrs.iter().filter(|s| matches!(s, SocketAddr::V4(_))); + std::iter::from_fn(move || { + pick_v6 = !pick_v6; + if pick_v6 { + v6.next().or_else(|| v4.next()) + } else { + v4.next().or_else(|| v6.next()) + } + }) +} + #[derive(Clone)] pub enum DnsResolver { System, @@ -112,6 +128,7 @@ impl DnsResolver { let mut opts = ResolverOpts::default(); opts.timeout = std::time::Duration::from_secs(1); + opts.ip_strategy = LookupIpStrategy::Ipv4AndIpv6; Ok(Self::TrustDns(AsyncResolver::new( cfg, opts, diff --git a/src/tcp.rs b/src/tcp.rs index 07d6037d..b9c80012 100644 --- a/src/tcp.rs +++ b/src/tcp.rs @@ -1,7 +1,8 @@ use anyhow::{anyhow, Context}; use std::{io, vec}; +use tokio::task::JoinSet; -use crate::dns::DnsResolver; +use crate::dns::{self, DnsResolver}; use base64::Engine; use bytes::BytesMut; use log::warn; @@ -11,7 +12,7 @@ use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpSocket, TcpStream}; -use tokio::time::timeout; +use tokio::time::{sleep, timeout}; use tokio_stream::wrappers::TcpListenerStream; use tracing::log::info; use tracing::{debug, instrument}; @@ -70,7 +71,9 @@ pub async fn connect( let mut cnx = None; let mut last_err = None; - for addr in socket_addrs { + let mut join_set = JoinSet::new(); + + for (ix, addr) in dns::sort_socket_addrs(&socket_addrs).copied().enumerate() { debug!("Connecting to {}", addr); let socket = match &addr { @@ -79,16 +82,45 @@ pub async fn connect( }; configure_socket(socket2::SockRef::from(&socket), &so_mark)?; - match timeout(connect_timeout, socket.connect(addr)).await { + + // Spawn the connection attempt in the join set. + // We include a delay of ix * 250 milliseconds, as per RFC8305. + // See https://datatracker.ietf.org/doc/html/rfc8305#section-5 + let fut = async move { + if ix > 0 { + sleep(Duration::from_millis(250 * ix as u64)).await; + } + match timeout(connect_timeout, socket.connect(addr)).await { + Ok(Ok(s)) => Ok(Ok(s)), + Ok(Err(e)) => Ok(Err((addr, e))), + Err(e) => Err((addr, e)), + } + }; + join_set.spawn(fut); + } + + // Wait for the next future that finishes in the join set, until we got one + // that resulted in a successful connection. + // If cnx is no longer None, we exit the loop, since this means that we got + // a successful connection. + while let (None, Some(res)) = (&cnx, join_set.join_next().await) { + match res? { Ok(Ok(stream)) => { + // We've got a successful connection, so we can abort all other + // on-going attempts. + join_set.abort_all(); + + debug!( + "Connected to tcp endpoint {}, aborted all other connection attempts", + stream.peer_addr()? + ); cnx = Some(stream); - break; } - Ok(Err(err)) => { - warn!("Cannot connect to tcp endpoint {addr} reason {err}"); + Ok(Err((addr, err))) => { + debug!("Cannot connect to tcp endpoint {addr} reason {err}"); last_err = Some(err); } - Err(_) => { + Err((addr, _)) => { warn!( "Cannot connect to tcp endpoint {addr} due to timeout of {}s elapsed", connect_timeout.as_secs() @@ -195,7 +227,7 @@ pub async fn run_server(bind: SocketAddr, ip_transparent: bool) -> Result = dns::sort_socket_addrs(&addrs).copied().collect(); + assert_eq!(expected, *actual); + } + #[tokio::test] async fn test_proxy_connection() { let server_addr: SocketAddr = "[::1]:1236".parse().unwrap(); diff --git a/src/udp.rs b/src/udp.rs index 082c4f97..6e6deee1 100644 --- a/src/udp.rs +++ b/src/udp.rs @@ -8,6 +8,7 @@ use std::future::Future; use std::io; use std::io::{Error, ErrorKind}; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; +use tokio::task::JoinSet; use log::warn; use std::pin::{pin, Pin}; @@ -18,9 +19,9 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::net::UdpSocket; use tokio::sync::futures::Notified; -use crate::dns::DnsResolver; +use crate::dns::{self, DnsResolver}; use tokio::sync::Notify; -use tokio::time::{timeout, Interval}; +use tokio::time::{sleep, timeout, Interval}; use tracing::{debug, error, info}; use url::Host; @@ -337,7 +338,9 @@ pub async fn connect( let mut cnx = None; let mut last_err = None; - for addr in socket_addrs { + let mut join_set = JoinSet::new(); + + for (ix, addr) in dns::sort_socket_addrs(&socket_addrs).copied().enumerate() { debug!("connecting to {}", addr); let socket = match &addr { @@ -353,18 +356,47 @@ pub async fn connect( } }; - match timeout(connect_timeout, socket.connect(addr)).await { - Ok(Ok(_)) => { + // Spawn the connection attempt in the join set. + // We include a delay of ix * 250 milliseconds, as per RFC8305. + // See https://datatracker.ietf.org/doc/html/rfc8305#section-5 + let fut = async move { + if ix > 0 { + sleep(Duration::from_millis(250 * ix as u64)).await; + } + + match timeout(connect_timeout, socket.connect(addr)).await { + Ok(Ok(())) => Ok(Ok(socket)), + Ok(Err(e)) => Ok(Err((addr, e))), + Err(e) => Err((addr, e)), + } + }; + join_set.spawn(fut); + } + + // Wait for the next future that finishes in the join set, until we got one + // that resulted in a successful connection. + // If cnx is no longer None, we exit the loop, since this means that we got + // a successful connection. + while let (None, Some(res)) = (&cnx, join_set.join_next().await) { + match res? { + Ok(Ok(socket)) => { + // We've got a successful connection, so we can abort all other + // on-going attempts. + join_set.abort_all(); + + debug!( + "Connected to udp endpoint {}, aborted all other connection attempts", + socket.peer_addr()? + ); cnx = Some(socket); - break; } - Ok(Err(err)) => { - debug!("Cannot connect udp socket to specified peer {addr} reason {err}"); + Ok(Err((addr, err))) => { + debug!("Cannot connect to udp endpoint {addr} reason {err}"); last_err = Some(err); } - Err(_) => { - debug!( - "Cannot connect udp socket to specified peer {addr} due to timeout of {}s elapsed", + Err((addr, _)) => { + warn!( + "Cannot connect to udp endpoint {addr} due to timeout of {}s elapsed", connect_timeout.as_secs() ); }