Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do DNS queries for both A and AAAA simultaneously #302

Merged
merged 2 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion src/dns.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::tcp;
use anyhow::{anyhow, Context};
use futures_util::{FutureExt, TryFutureExt};
use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts};
use hickory_resolver::config::{LookupIpStrategy, NameServerConfig, Protocol, ResolverConfig, ResolverOpts};
use hickory_resolver::name_server::{GenericConnector, RuntimeProvider, TokioRuntimeProvider};
use hickory_resolver::proto::iocompat::AsyncIoTokioAsStd;
use hickory_resolver::proto::TokioTime;
Expand All @@ -15,6 +15,22 @@ use std::time::Duration;
use tokio::net::{TcpStream, UdpSocket};
use url::{Host, Url};

// Interweave v4 and v6 addresses as per RFC8305.
// The first address is v6 if we have any v6 addresses.
pub fn sort_socket_addrs(socket_addrs: &[SocketAddr]) -> impl Iterator<Item = &'_ SocketAddr> {
let mut pick_v6 = false;
let mut v6 = socket_addrs.iter().filter(|s| matches!(s, SocketAddr::V6(_)));
let mut v4 = socket_addrs.iter().filter(|s| matches!(s, SocketAddr::V4(_)));
std::iter::from_fn(move || {
pick_v6 = !pick_v6;
if pick_v6 {
v6.next().or_else(|| v4.next())
} else {
v4.next().or_else(|| v6.next())
}
})
}

#[derive(Clone)]
pub enum DnsResolver {
System,
Expand Down Expand Up @@ -112,6 +128,7 @@ impl DnsResolver {

let mut opts = ResolverOpts::default();
opts.timeout = std::time::Duration::from_secs(1);
opts.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
Ok(Self::TrustDns(AsyncResolver::new(
cfg,
opts,
Expand Down
70 changes: 61 additions & 9 deletions src/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use anyhow::{anyhow, Context};
use std::{io, vec};
use tokio::task::JoinSet;

use crate::dns::DnsResolver;
use crate::dns::{self, DnsResolver};
use base64::Engine;
use bytes::BytesMut;
use log::warn;
Expand All @@ -11,7 +12,7 @@ use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio::time::timeout;
use tokio::time::{sleep, timeout};
use tokio_stream::wrappers::TcpListenerStream;
use tracing::log::info;
use tracing::{debug, instrument};
Expand Down Expand Up @@ -70,7 +71,9 @@ pub async fn connect(

let mut cnx = None;
let mut last_err = None;
for addr in socket_addrs {
let mut join_set = JoinSet::new();

for (ix, addr) in dns::sort_socket_addrs(&socket_addrs).copied().enumerate() {
debug!("Connecting to {}", addr);

let socket = match &addr {
Expand All @@ -79,16 +82,45 @@ pub async fn connect(
};

configure_socket(socket2::SockRef::from(&socket), &so_mark)?;
match timeout(connect_timeout, socket.connect(addr)).await {

// Spawn the connection attempt in the join set.
// We include a delay of ix * 250 milliseconds, as per RFC8305.
// See https://datatracker.ietf.org/doc/html/rfc8305#section-5
let fut = async move {
if ix > 0 {
sleep(Duration::from_millis(250 * ix as u64)).await;
}
match timeout(connect_timeout, socket.connect(addr)).await {
Ok(Ok(s)) => Ok(Ok(s)),
Ok(Err(e)) => Ok(Err((addr, e))),
Err(e) => Err((addr, e)),
}
};
join_set.spawn(fut);
}

// Wait for the next future that finishes in the join set, until we got one
// that resulted in a successful connection.
// If cnx is no longer None, we exit the loop, since this means that we got
// a successful connection.
while let (None, Some(res)) = (&cnx, join_set.join_next().await) {
match res? {
Ok(Ok(stream)) => {
// We've got a successful connection, so we can abort all other
// on-going attempts.
join_set.abort_all();

debug!(
"Connected to tcp endpoint {}, aborted all other connection attempts",
stream.peer_addr()?
);
cnx = Some(stream);
break;
}
Ok(Err(err)) => {
warn!("Cannot connect to tcp endpoint {addr} reason {err}");
Ok(Err((addr, err))) => {
debug!("Cannot connect to tcp endpoint {addr} reason {err}");
last_err = Some(err);
}
Err(_) => {
Err((addr, _)) => {
warn!(
"Cannot connect to tcp endpoint {addr} due to timeout of {}s elapsed",
connect_timeout.as_secs()
Expand Down Expand Up @@ -195,7 +227,7 @@ pub async fn run_server(bind: SocketAddr, ip_transparent: bool) -> Result<TcpLis
mod tests {
use super::*;
use futures_util::pin_mut;
use std::net::SocketAddr;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use testcontainers::core::WaitFor;
use testcontainers::runners::AsyncRunner;
use testcontainers::{ContainerAsync, Image, ImageArgs, RunnableImage};
Expand Down Expand Up @@ -227,6 +259,26 @@ mod tests {
}
}

#[test]
fn test_sort_socket_addrs() {
let addrs = [
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1)),
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 2), 1)),
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 1), 1, 0, 0)),
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 3), 1)),
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 2), 1, 0, 0)),
];
let expected = [
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 1), 1, 0, 0)),
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1)),
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 2), 1, 0, 0)),
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 2), 1)),
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 3), 1)),
];
let actual: Vec<_> = dns::sort_socket_addrs(&addrs).copied().collect();
assert_eq!(expected, *actual);
}

#[tokio::test]
async fn test_proxy_connection() {
let server_addr: SocketAddr = "[::1]:1236".parse().unwrap();
Expand Down
54 changes: 43 additions & 11 deletions src/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::future::Future;
use std::io;
use std::io::{Error, ErrorKind};
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use tokio::task::JoinSet;

use log::warn;
use std::pin::{pin, Pin};
Expand All @@ -18,9 +19,9 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::UdpSocket;
use tokio::sync::futures::Notified;

use crate::dns::DnsResolver;
use crate::dns::{self, DnsResolver};
use tokio::sync::Notify;
use tokio::time::{timeout, Interval};
use tokio::time::{sleep, timeout, Interval};
use tracing::{debug, error, info};
use url::Host;

Expand Down Expand Up @@ -337,7 +338,9 @@ pub async fn connect(

let mut cnx = None;
let mut last_err = None;
for addr in socket_addrs {
let mut join_set = JoinSet::new();

for (ix, addr) in dns::sort_socket_addrs(&socket_addrs).copied().enumerate() {
debug!("connecting to {}", addr);

let socket = match &addr {
Expand All @@ -353,18 +356,47 @@ pub async fn connect(
}
};

match timeout(connect_timeout, socket.connect(addr)).await {
Ok(Ok(_)) => {
// Spawn the connection attempt in the join set.
// We include a delay of ix * 250 milliseconds, as per RFC8305.
// See https://datatracker.ietf.org/doc/html/rfc8305#section-5
let fut = async move {
if ix > 0 {
sleep(Duration::from_millis(250 * ix as u64)).await;
}

match timeout(connect_timeout, socket.connect(addr)).await {
Ok(Ok(())) => Ok(Ok(socket)),
Ok(Err(e)) => Ok(Err((addr, e))),
Err(e) => Err((addr, e)),
}
};
join_set.spawn(fut);
}

// Wait for the next future that finishes in the join set, until we got one
// that resulted in a successful connection.
// If cnx is no longer None, we exit the loop, since this means that we got
// a successful connection.
while let (None, Some(res)) = (&cnx, join_set.join_next().await) {
match res? {
Ok(Ok(socket)) => {
// We've got a successful connection, so we can abort all other
// on-going attempts.
join_set.abort_all();

debug!(
"Connected to udp endpoint {}, aborted all other connection attempts",
socket.peer_addr()?
);
cnx = Some(socket);
break;
}
Ok(Err(err)) => {
debug!("Cannot connect udp socket to specified peer {addr} reason {err}");
Ok(Err((addr, err))) => {
debug!("Cannot connect to udp endpoint {addr} reason {err}");
last_err = Some(err);
}
Err(_) => {
debug!(
"Cannot connect udp socket to specified peer {addr} due to timeout of {}s elapsed",
Err((addr, _)) => {
warn!(
"Cannot connect to udp endpoint {addr} due to timeout of {}s elapsed",
connect_timeout.as_secs()
);
}
Expand Down