Skip to content

Commit

Permalink
Do DNS queries for both A and AAAA simultaneously
Browse files Browse the repository at this point in the history
We implement a basic version of RFC8305 (happy eyeballs) to establish
the connection afterwards.
  • Loading branch information
r-vdp committed Jul 10, 2024
1 parent 4f570dc commit 332df2a
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 12 deletions.
19 changes: 18 additions & 1 deletion src/dns.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::tcp;
use anyhow::{anyhow, Context};
use futures_util::{FutureExt, TryFutureExt};
use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts};
use hickory_resolver::config::{LookupIpStrategy, NameServerConfig, Protocol, ResolverConfig, ResolverOpts};
use hickory_resolver::name_server::{GenericConnector, RuntimeProvider, TokioRuntimeProvider};
use hickory_resolver::proto::iocompat::AsyncIoTokioAsStd;
use hickory_resolver::proto::TokioTime;
Expand All @@ -15,6 +15,22 @@ use std::time::Duration;
use tokio::net::{TcpStream, UdpSocket};
use url::{Host, Url};

// Interweave v4 and v6 addresses as per RFC8305.
// The first address is v6 if we have any v6 addresses.
pub fn sort_socket_addrs(socket_addrs: &[SocketAddr]) -> impl Iterator<Item = &'_ SocketAddr> {
let mut pick_v6 = false;
let mut v6 = socket_addrs.iter().filter(|s| matches!(s, SocketAddr::V6(_)));
let mut v4 = socket_addrs.iter().filter(|s| matches!(s, SocketAddr::V4(_)));
std::iter::from_fn(move || {
pick_v6 = !pick_v6;
if pick_v6 {
v6.next().or_else(|| v4.next())
} else {
v4.next().or_else(|| v6.next())
}
})
}

#[derive(Clone)]
pub enum DnsResolver {
System,
Expand Down Expand Up @@ -112,6 +128,7 @@ impl DnsResolver {

let mut opts = ResolverOpts::default();
opts.timeout = std::time::Duration::from_secs(1);
opts.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
Ok(Self::TrustDns(AsyncResolver::new(
cfg,
opts,
Expand Down
70 changes: 61 additions & 9 deletions src/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use anyhow::{anyhow, Context};
use std::{io, vec};
use tokio::task::JoinSet;

use crate::dns::DnsResolver;
use crate::dns::{self, DnsResolver};
use base64::Engine;
use bytes::BytesMut;
use log::warn;
Expand All @@ -11,7 +12,7 @@ use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio::time::timeout;
use tokio::time::{sleep, timeout};
use tokio_stream::wrappers::TcpListenerStream;
use tracing::log::info;
use tracing::{debug, instrument};
Expand Down Expand Up @@ -70,7 +71,9 @@ pub async fn connect(

let mut cnx = None;
let mut last_err = None;
for addr in socket_addrs {
let mut join_set = JoinSet::new();

for (ix, addr) in dns::sort_socket_addrs(&socket_addrs).copied().enumerate() {
debug!("Connecting to {}", addr);

let socket = match &addr {
Expand All @@ -79,16 +82,45 @@ pub async fn connect(
};

configure_socket(socket2::SockRef::from(&socket), &so_mark)?;
match timeout(connect_timeout, socket.connect(addr)).await {

// Spawn the connection attempt in the join set.
// We include a delay of ix * 250 milliseconds, as per RFC8305.
// See https://datatracker.ietf.org/doc/html/rfc8305#section-5
let fut = async move {
if ix > 0 {
sleep(Duration::from_millis(250 * ix as u64)).await;
}
match timeout(connect_timeout, socket.connect(addr)).await {
Ok(Ok(s)) => Ok(Ok(s)),
Ok(Err(e)) => Ok(Err((addr, e))),
Err(e) => Err((addr, e)),
}
};
join_set.spawn(fut);
}

// Wait for the next future that finishes in the join set, until we got one
// that resulted in a successful connection.
// If cnx is no longer None, we exit the loop, since this means that we got
// a successful connection.
while let (None, Some(res)) = (&cnx, join_set.join_next().await) {
match res? {
Ok(Ok(stream)) => {
// We've got a successful connection, so we can abort all other
// on-going attempts.
join_set.abort_all();

debug!(
"Connected to tcp endpoint {}, aborted all other connection attempts",
stream.peer_addr()?
);
cnx = Some(stream);
break;
}
Ok(Err(err)) => {
warn!("Cannot connect to tcp endpoint {addr} reason {err}");
Ok(Err((addr, err))) => {
debug!("Cannot connect to tcp endpoint {addr} reason {err}");
last_err = Some(err);
}
Err(_) => {
Err((addr, _)) => {
warn!(
"Cannot connect to tcp endpoint {addr} due to timeout of {}s elapsed",
connect_timeout.as_secs()
Expand Down Expand Up @@ -195,7 +227,7 @@ pub async fn run_server(bind: SocketAddr, ip_transparent: bool) -> Result<TcpLis
mod tests {
use super::*;
use futures_util::pin_mut;
use std::net::SocketAddr;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use testcontainers::core::WaitFor;
use testcontainers::runners::AsyncRunner;
use testcontainers::{ContainerAsync, Image, ImageArgs, RunnableImage};
Expand Down Expand Up @@ -227,6 +259,26 @@ mod tests {
}
}

#[test]
fn test_sort_socket_addrs() {
let addrs = [
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1)),
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 2), 1)),
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 1), 1, 0, 0)),
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 3), 1)),
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 2), 1, 0, 0)),
];
let expected = [
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 1), 1, 0, 0)),
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1)),
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 2), 1, 0, 0)),
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 2), 1)),
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 3), 1)),
];
let actual: Vec<_> = dns::sort_socket_addrs(&addrs).copied().collect();
assert_eq!(expected, *actual);
}

#[tokio::test]
async fn test_proxy_connection() {
let server_addr: SocketAddr = "[::1]:1236".parse().unwrap();
Expand Down
4 changes: 2 additions & 2 deletions src/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::UdpSocket;
use tokio::sync::futures::Notified;

use crate::dns::DnsResolver;
use crate::dns::{self, DnsResolver};
use tokio::sync::Notify;
use tokio::time::{timeout, Interval};
use tracing::{debug, error, info};
Expand Down Expand Up @@ -337,7 +337,7 @@ pub async fn connect(

let mut cnx = None;
let mut last_err = None;
for addr in socket_addrs {
for addr in dns::sort_socket_addrs(&socket_addrs) {
debug!("connecting to {}", addr);

let socket = match &addr {
Expand Down

0 comments on commit 332df2a

Please sign in to comment.