Skip to content

Commit

Permalink
feat(dns): Add flag to specify if we should prefer IPv4 over IPv6
Browse files Browse the repository at this point in the history
  • Loading branch information
erebe committed Jul 20, 2024
1 parent 90d378e commit 711ceb9
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 58 deletions.
96 changes: 70 additions & 26 deletions src/dns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ use std::time::Duration;
use tokio::net::{TcpStream, UdpSocket};
use url::{Host, Url};

// Interweave v4 and v6 addresses as per RFC8305.
// Interleave v4 and v6 addresses as per RFC8305.
// The first address is v6 if we have any v6 addresses.
pub fn sort_socket_addrs(socket_addrs: &[SocketAddr]) -> impl Iterator<Item = &'_ SocketAddr> {
let mut pick_v6 = false;
#[inline]
fn sort_socket_addrs(socket_addrs: &[SocketAddr], prefer_ipv6: bool) -> impl Iterator<Item = &'_ SocketAddr> {
let mut pick_v6 = !prefer_ipv6;
let mut v6 = socket_addrs.iter().filter(|s| matches!(s, SocketAddr::V6(_)));
let mut v4 = socket_addrs.iter().filter(|s| matches!(s, SocketAddr::V4(_)));
std::iter::from_fn(move || {
Expand All @@ -34,48 +35,62 @@ pub fn sort_socket_addrs(socket_addrs: &[SocketAddr]) -> impl Iterator<Item = &'
#[derive(Clone)]
pub enum DnsResolver {
System,
TrustDns(AsyncResolver<GenericConnector<TokioRuntimeProviderWithSoMark>>),
TrustDns {
resolver: AsyncResolver<GenericConnector<TokioRuntimeProviderWithSoMark>>,
prefer_ipv6: bool,
},
}

impl DnsResolver {
pub async fn lookup_host(&self, domain: &str, port: u16) -> anyhow::Result<Vec<SocketAddr>> {
let addrs: Vec<SocketAddr> = match self {
Self::System => tokio::net::lookup_host(format!("{}:{}", domain, port)).await?.collect(),
Self::TrustDns(dns_resolver) => dns_resolver
.lookup_ip(domain)
.await?
.into_iter()
.map(|ip| match ip {
IpAddr::V4(ip) => SocketAddr::V4(SocketAddrV4::new(ip, port)),
IpAddr::V6(ip) => SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)),
})
.collect(),
Self::TrustDns { resolver, prefer_ipv6 } => {
let addrs: Vec<_> = resolver
.lookup_ip(domain)
.await?
.into_iter()
.map(|ip| match ip {
IpAddr::V4(ip) => SocketAddr::V4(SocketAddrV4::new(ip, port)),
IpAddr::V6(ip) => SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)),
})
.collect();
sort_socket_addrs(&addrs, *prefer_ipv6).copied().collect()
}
};

Ok(addrs)
}

pub fn new_from_urls(resolvers: &[Url], proxy: Option<Url>, so_mark: Option<u32>) -> anyhow::Result<Self> {
pub fn new_from_urls(
resolvers: &[Url],
proxy: Option<Url>,
so_mark: Option<u32>,
prefer_ipv6: bool,
) -> anyhow::Result<Self> {
if resolvers.is_empty() {
// no dns resolver specified, fall-back to default one
let Ok((cfg, mut opts)) = hickory_resolver::system_conf::read_system_conf() else {
warn!("Fall-backing to system dns resolver. You should consider specifying a dns resolver. To avoid performance issue");
return Ok(Self::System);
};

opts.timeout = std::time::Duration::from_secs(1);
opts.timeout = Duration::from_secs(1);
// Windows end-up with too many dns resolvers, which causes a performance issue
// https://github.com/hickory-dns/hickory-dns/issues/1968
#[cfg(target_os = "windows")]
{
opts.cache_size = 1024;
opts.num_concurrent_reqs = cfg.name_servers().len();
}
return Ok(Self::TrustDns(AsyncResolver::new(
cfg,
opts,
GenericConnector::new(TokioRuntimeProviderWithSoMark::new(proxy, so_mark)),
)));
return Ok(Self::TrustDns {
resolver: AsyncResolver::new(
cfg,
opts,
GenericConnector::new(TokioRuntimeProviderWithSoMark::new(proxy, so_mark)),
),
prefer_ipv6,
});
};

// if one is specified as system, use the default one from libc
Expand Down Expand Up @@ -127,13 +142,16 @@ impl DnsResolver {
}

let mut opts = ResolverOpts::default();
opts.timeout = std::time::Duration::from_secs(1);
opts.timeout = Duration::from_secs(1);
opts.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
Ok(Self::TrustDns(AsyncResolver::new(
cfg,
opts,
GenericConnector::new(TokioRuntimeProviderWithSoMark::new(proxy, so_mark)),
)))
Ok(Self::TrustDns {
resolver: AsyncResolver::new(
cfg,
opts,
GenericConnector::new(TokioRuntimeProviderWithSoMark::new(proxy, so_mark)),
),
prefer_ipv6,
})
}
}

Expand Down Expand Up @@ -235,3 +253,29 @@ impl RuntimeProvider for TokioRuntimeProviderWithSoMark {
Box::pin(socket)
}
}

#[cfg(test)]
mod tests {
use crate::dns::sort_socket_addrs;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};

#[test]
fn test_sort_socket_addrs() {
let addrs = [
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1)),
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 2), 1)),
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 1), 1, 0, 0)),
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 3), 1)),
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 2), 1, 0, 0)),
];
let expected = [
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 1), 1, 0, 0)),
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1)),
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 2), 1, 0, 0)),
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 2), 1)),
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 3), 1)),
];
let actual: Vec<_> = sort_socket_addrs(&addrs, true).copied().collect();
assert_eq!(expected, *actual);
}
}
39 changes: 33 additions & 6 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,12 @@ struct Client {
/// **WARN** On windows you may want to specify explicitly the DNS resolver to avoid excessive DNS queries
#[arg(long, verbatim_doc_comment)]
dns_resolver: Vec<Url>,

/// Enable if you prefer the dns resolver to prioritize IPv4 over IPv6
/// This is useful if you have a broken IPv6 connection, and want to avoid the delay of trying to connect to IPv6
/// If you don't have any IPv6 this does not change anything.
#[arg(long, default_value = "false", verbatim_doc_comment)]
dns_resolver_prefer_ipv4: bool,
}

#[derive(clap::Args, Debug)]
Expand Down Expand Up @@ -295,6 +301,12 @@ struct Server {
#[arg(long, verbatim_doc_comment)]
dns_resolver: Vec<Url>,

/// Enable if you prefer the dns resolver to prioritize IPv4 over IPv6
/// This is useful if you have a broken IPv6 connection, and want to avoid the delay of trying to connect to IPv6
/// If you don't have any IPv6 this does not change anything.
#[arg(long, default_value = "false", verbatim_doc_comment)]
dns_resolver_prefer_ipv4: bool,

/// Server will only accept connection from the specified tunnel information.
/// Can be specified multiple time
/// Example: --restrict-to "google.com:443" --restrict-to "localhost:22"
Expand Down Expand Up @@ -755,8 +767,13 @@ impl WsClientConfig {
#[tokio::main]
async fn main() {
let args = Wstunnel::parse();
let socket = UdpSocket::bind(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0)).await.unwrap();
socket.connect("[2001:4810:0:3::78]:443".parse::<SocketAddr>().unwrap()).await.unwrap();
let socket = UdpSocket::bind(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0))
.await
.unwrap();
socket
.connect("[2001:4810:0:3::78]:443".parse::<SocketAddr>().unwrap())
.await
.unwrap();

// Setup logging
let mut env_filter = EnvFilter::builder().parse(&args.log_lvl).expect("Invalid log level");
Expand Down Expand Up @@ -902,8 +919,13 @@ async fn main() {
websocket_mask_frame: args.websocket_mask_frame,
cnx_pool: None,
tls_reloader: None,
dns_resolver: DnsResolver::new_from_urls(&args.dns_resolver, http_proxy.clone(), args.socket_so_mark)
.expect("cannot create dns resolver"),
dns_resolver: DnsResolver::new_from_urls(
&args.dns_resolver,
http_proxy.clone(),
args.socket_so_mark,
!args.dns_resolver_prefer_ipv4,
)
.expect("cannot create dns resolver"),
http_proxy,
};

Expand Down Expand Up @@ -1324,8 +1346,13 @@ async fn main() {
timeout_connect: Duration::from_secs(10),
websocket_mask_frame: args.websocket_mask_frame,
tls: tls_config,
dns_resolver: DnsResolver::new_from_urls(&args.dns_resolver, None, args.socket_so_mark)
.expect("Cannot create DNS resolver"),
dns_resolver: DnsResolver::new_from_urls(
&args.dns_resolver,
None,
args.socket_so_mark,
!args.dns_resolver_prefer_ipv4,
)
.expect("Cannot create DNS resolver"),
restriction_config: args.restrict_config,
};

Expand Down
28 changes: 4 additions & 24 deletions src/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use anyhow::{anyhow, Context};
use std::{io, vec};
use tokio::task::JoinSet;

use crate::dns::{self, DnsResolver};
use crate::dns::DnsResolver;
use base64::Engine;
use bytes::BytesMut;
use log::warn;
Expand Down Expand Up @@ -73,7 +73,7 @@ pub async fn connect(
let mut last_err = None;
let mut join_set = JoinSet::new();

for (ix, addr) in dns::sort_socket_addrs(&socket_addrs).copied().enumerate() {
for (ix, addr) in socket_addrs.into_iter().enumerate() {
debug!("Connecting to {}", addr);

let socket = match &addr {
Expand Down Expand Up @@ -107,7 +107,7 @@ pub async fn connect(
match res? {
Ok(Ok(stream)) => {
// We've got a successful connection, so we can abort all other
// on-going attempts.
// ongoing attempts.
join_set.abort_all();

debug!(
Expand Down Expand Up @@ -227,7 +227,7 @@ pub async fn run_server(bind: SocketAddr, ip_transparent: bool) -> Result<TcpLis
mod tests {
use super::*;
use futures_util::pin_mut;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use std::net::SocketAddr;
use testcontainers::core::WaitFor;
use testcontainers::runners::AsyncRunner;
use testcontainers::{ContainerAsync, Image, ImageArgs, RunnableImage};
Expand Down Expand Up @@ -259,26 +259,6 @@ mod tests {
}
}

#[test]
fn test_sort_socket_addrs() {
let addrs = [
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1)),
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 2), 1)),
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 1), 1, 0, 0)),
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 3), 1)),
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 2), 1, 0, 0)),
];
let expected = [
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 1), 1, 0, 0)),
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1)),
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 2), 1, 0, 0)),
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 2), 1)),
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 3), 1)),
];
let actual: Vec<_> = dns::sort_socket_addrs(&addrs).copied().collect();
assert_eq!(expected, *actual);
}

#[tokio::test]
async fn test_proxy_connection() {
let server_addr: SocketAddr = "[::1]:1236".parse().unwrap();
Expand Down
4 changes: 2 additions & 2 deletions src/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::UdpSocket;
use tokio::sync::futures::Notified;

use crate::dns::{self, DnsResolver};
use crate::dns::DnsResolver;
use tokio::sync::Notify;
use tokio::time::{sleep, timeout, Interval};
use tracing::{debug, error, info};
Expand Down Expand Up @@ -340,7 +340,7 @@ pub async fn connect(
let mut last_err = None;
let mut join_set = JoinSet::new();

for (ix, addr) in dns::sort_socket_addrs(&socket_addrs).copied().enumerate() {
for (ix, addr) in socket_addrs.into_iter().enumerate() {
debug!("connecting to {}", addr);

let socket = match &addr {
Expand Down

0 comments on commit 711ceb9

Please sign in to comment.