Skip to content

Commit

Permalink
Refacto: Use proper type for WsClient
Browse files Browse the repository at this point in the history
  • Loading branch information
erebe committed Jul 29, 2024
1 parent 5e74ed2 commit 6f96808
Show file tree
Hide file tree
Showing 9 changed files with 180 additions and 146 deletions.
118 changes: 25 additions & 93 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ use clap::Parser;
use hyper::header::HOST;
use hyper::http::{HeaderName, HeaderValue};
use log::debug;
use once_cell::sync::Lazy;
use parking_lot::{Mutex, RwLock};
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap};
use std::collections::BTreeMap;
use std::fmt::{Debug, Formatter};
use std::io::ErrorKind;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
Expand All @@ -22,14 +21,15 @@ use std::time::Duration;
use std::{fmt, io};
use tokio::select;

use tokio_rustls::rustls::pki_types::{CertificateDer, DnsName, PrivateKeyDer, ServerName};
use tokio_rustls::rustls::pki_types::{CertificateDer, DnsName, PrivateKeyDer};
use tokio_rustls::TlsConnector;

use tracing::{error, info};

use crate::protocols::dns::DnsResolver;
use crate::protocols::tls;
use crate::restrictions::types::RestrictionsRules;
use crate::tunnel::client::{WsClient, WsClientConfig};
use crate::tunnel::connectors::{Socks5TunnelConnector, TcpTunnelConnector, UdpTunnelConnector};
use crate::tunnel::listeners::{
new_stdio_listener, new_udp_listener, HttpProxyTunnelListener, Socks5TunnelListener, TcpTunnelListener,
Expand Down Expand Up @@ -754,59 +754,6 @@ impl Debug for WsServerConfig {
}
}

#[derive(Clone)]
pub struct WsClientConfig {
pub remote_addr: TransportAddr,
pub socket_so_mark: Option<u32>,
pub http_upgrade_path_prefix: String,
pub http_upgrade_credentials: Option<HeaderValue>,
pub http_headers: HashMap<HeaderName, HeaderValue>,
pub http_headers_file: Option<PathBuf>,
pub http_header_host: HeaderValue,
pub timeout_connect: Duration,
pub websocket_ping_frequency: Duration,
pub websocket_mask_frame: bool,
pub http_proxy: Option<Url>,
cnx_pool: Option<bb8::Pool<WsClientConfig>>,
tls_reloader: Option<Arc<TlsReloader>>,
pub dns_resolver: DnsResolver,
}

impl WsClientConfig {
pub const fn websocket_scheme(&self) -> &'static str {
match self.remote_addr.tls().is_some() {
false => "ws",
true => "wss",
}
}

pub fn cnx_pool(&self) -> &bb8::Pool<Self> {
self.cnx_pool.as_ref().unwrap()
}

pub fn websocket_host_url(&self) -> String {
format!("{}:{}", self.remote_addr.host(), self.remote_addr.port())
}

pub fn tls_server_name(&self) -> ServerName<'static> {
static INVALID_DNS_NAME: Lazy<DnsName> = Lazy::new(|| DnsName::try_from("dns-name-invalid.com").unwrap());

self.remote_addr
.tls()
.and_then(|tls| tls.tls_sni_override.as_ref())
.map_or_else(
|| match &self.remote_addr.host() {
Host::Domain(domain) => ServerName::DnsName(
DnsName::try_from(domain.clone()).unwrap_or_else(|_| INVALID_DNS_NAME.clone()),
),
Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(*ip).into()),
Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(*ip).into()),
},
|sni_override| ServerName::DnsName(sni_override.clone()),
)
}
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
let args = Wstunnel::parse();
Expand Down Expand Up @@ -953,7 +900,6 @@ async fn main() -> anyhow::Result<()> {
timeout_connect: Duration::from_secs(10),
websocket_ping_frequency: args.websocket_ping_frequency_sec.unwrap_or(Duration::from_secs(30)),
websocket_mask_frame: args.websocket_mask_frame,
cnx_pool: None,
tls_reloader: None,
dns_resolver: DnsResolver::new_from_urls(
&args.dns_resolver,
Expand All @@ -968,25 +914,16 @@ async fn main() -> anyhow::Result<()> {
let tls_reloader =
TlsReloader::new_for_client(Arc::new(client_config.clone())).expect("Cannot create tls reloader");
client_config.tls_reloader = Some(Arc::new(tls_reloader));
let pool = bb8::Pool::builder()
.max_size(1000)
.min_idle(Some(args.connection_min_idle))
.max_lifetime(Some(Duration::from_secs(30)))
.connection_timeout(args.connection_retry_max_backoff_sec)
.retry_connection(true)
.build(client_config.clone())
.await
.unwrap();
client_config.cnx_pool = Some(pool);
let client_config = Arc::new(client_config);
let client =
WsClient::new(client_config, args.connection_min_idle, args.connection_retry_max_backoff_sec).await?;

// Start tunnels
for tunnel in args.remote_to_local.into_iter() {
let client_config = client_config.clone();
let client = client.clone();
match &tunnel.local_protocol {
LocalProtocol::Tcp { proxy_protocol: _ } => {
tokio::spawn(async move {
let cfg = client_config.clone();
let cfg = client.config.clone();
let tcp_connector = TcpTunnelConnector::new(
&tunnel.remote.0,
tunnel.remote.1,
Expand All @@ -1000,9 +937,7 @@ async fn main() -> anyhow::Result<()> {
host,
port,
};
if let Err(err) =
tunnel::client::run_reverse_tunnel(client_config, remote, tcp_connector).await
{
if let Err(err) = tunnel::client::run_reverse_tunnel(client, remote, tcp_connector).await {
error!("{:?}", err);
}
});
Expand All @@ -1011,7 +946,7 @@ async fn main() -> anyhow::Result<()> {
let timeout = *timeout;

tokio::spawn(async move {
let cfg = client_config.clone();
let cfg = client.config.clone();
let (host, port) = to_host_port(tunnel.local);
let remote = RemoteAddr {
protocol: LocalProtocol::ReverseUdp { timeout },
Expand All @@ -1027,7 +962,7 @@ async fn main() -> anyhow::Result<()> {
);

if let Err(err) =
tunnel::client::run_reverse_tunnel(client_config, remote.clone(), udp_connector).await
tunnel::client::run_reverse_tunnel(client, remote.clone(), udp_connector).await
{
error!("{:?}", err);
}
Expand All @@ -1037,7 +972,7 @@ async fn main() -> anyhow::Result<()> {
let credentials = credentials.clone();
let timeout = *timeout;
tokio::spawn(async move {
let cfg = client_config.clone();
let cfg = client.config.clone();
let (host, port) = to_host_port(tunnel.local);
let remote = RemoteAddr {
protocol: LocalProtocol::ReverseSocks5 { timeout, credentials },
Expand All @@ -1047,8 +982,7 @@ async fn main() -> anyhow::Result<()> {
let socks_connector =
Socks5TunnelConnector::new(cfg.socket_so_mark, cfg.timeout_connect, &cfg.dns_resolver);

if let Err(err) =
tunnel::client::run_reverse_tunnel(client_config, remote, socks_connector).await
if let Err(err) = tunnel::client::run_reverse_tunnel(client, remote, socks_connector).await
{
error!("{:?}", err);
}
Expand All @@ -1060,7 +994,7 @@ async fn main() -> anyhow::Result<()> {
let credentials = credentials.clone();
let timeout = *timeout;
tokio::spawn(async move {
let cfg = client_config.clone();
let cfg = client.config.clone();
let (host, port) = to_host_port(tunnel.local);
let remote = RemoteAddr {
protocol: LocalProtocol::ReverseHttpProxy { timeout, credentials },
Expand All @@ -1076,7 +1010,7 @@ async fn main() -> anyhow::Result<()> {
);

if let Err(err) =
tunnel::client::run_reverse_tunnel(client_config, remote.clone(), tcp_connector).await
tunnel::client::run_reverse_tunnel(client, remote.clone(), tcp_connector).await
{
error!("{:?}", err);
}
Expand All @@ -1086,7 +1020,7 @@ async fn main() -> anyhow::Result<()> {
LocalProtocol::Unix { path } => {
let path = path.clone();
tokio::spawn(async move {
let cfg = client_config.clone();
let cfg = client.config.clone();
let tcp_connector = TcpTunnelConnector::new(
&tunnel.remote.0,
tunnel.remote.1,
Expand All @@ -1101,9 +1035,7 @@ async fn main() -> anyhow::Result<()> {
host,
port,
};
if let Err(err) =
tunnel::client::run_reverse_tunnel(client_config, remote, tcp_connector).await
{
if let Err(err) = tunnel::client::run_reverse_tunnel(client, remote, tcp_connector).await {
error!("{:?}", err);
}
});
Expand All @@ -1126,14 +1058,14 @@ async fn main() -> anyhow::Result<()> {
}

for tunnel in args.local_to_remote.into_iter() {
let client_config = client_config.clone();
let client = client.clone();

match &tunnel.local_protocol {
LocalProtocol::Tcp { proxy_protocol } => {
let server =
TcpTunnelListener::new(tunnel.local, tunnel.remote.clone(), *proxy_protocol).await?;
tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
if let Err(err) = tunnel::client::run_tunnel(client, server).await {
error!("{:?}", err);
}
});
Expand All @@ -1144,7 +1076,7 @@ async fn main() -> anyhow::Result<()> {
let server = TproxyTcpTunnelListener::new(tunnel.local, false).await?;

tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
if let Err(err) = tunnel::client::run_tunnel(client, server).await {
error!("{:?}", err);
}
});
Expand All @@ -1154,7 +1086,7 @@ async fn main() -> anyhow::Result<()> {
use crate::tunnel::listeners::UnixTunnelListener;
let server = UnixTunnelListener::new(path, tunnel.remote.clone(), false).await?; // TODO: support proxy protocol
tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
if let Err(err) = tunnel::client::run_tunnel(client, server).await {
error!("{:?}", err);
}
});
Expand All @@ -1169,7 +1101,7 @@ async fn main() -> anyhow::Result<()> {
use crate::tunnel::listeners::new_tproxy_udp;
let server = new_tproxy_udp(tunnel.local, *timeout).await?;
tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
if let Err(err) = tunnel::client::run_tunnel(client, server).await {
error!("{:?}", err);
}
});
Expand All @@ -1182,15 +1114,15 @@ async fn main() -> anyhow::Result<()> {
let server = new_udp_listener(tunnel.local, tunnel.remote.clone(), *timeout).await?;

tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
if let Err(err) = tunnel::client::run_tunnel(client, server).await {
error!("{:?}", err);
}
});
}
LocalProtocol::Socks5 { timeout, credentials } => {
let server = Socks5TunnelListener::new(tunnel.local, *timeout, credentials.clone()).await?;
tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
if let Err(err) = tunnel::client::run_tunnel(client, server).await {
error!("{:?}", err);
}
});
Expand All @@ -1204,7 +1136,7 @@ async fn main() -> anyhow::Result<()> {
HttpProxyTunnelListener::new(tunnel.local, *timeout, credentials.clone(), *proxy_protocol)
.await?;
tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
if let Err(err) = tunnel::client::run_tunnel(client, server).await {
error!("{:?}", err);
}
});
Expand All @@ -1213,7 +1145,7 @@ async fn main() -> anyhow::Result<()> {
LocalProtocol::Stdio => {
let (server, mut handle) = new_stdio_listener(tunnel.remote.clone(), false).await?; // TODO: support proxy protocol
tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
if let Err(err) = tunnel::client::run_tunnel(client, server).await {
error!("{:?}", err);
}
});
Expand Down
3 changes: 2 additions & 1 deletion src/protocols/tls/server.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{TlsServerConfig, WsClientConfig};
use crate::TlsServerConfig;
use anyhow::{anyhow, Context};
use std::fs::File;

Expand All @@ -9,6 +9,7 @@ use std::sync::Arc;
use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;

use crate::tunnel::client::WsClientConfig;
use crate::tunnel::TransportAddr;
use tokio_rustls::rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
Expand Down
7 changes: 7 additions & 0 deletions src/tunnel/client/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
mod types;
mod utils;

pub use types::WsClient;
pub use types::WsClientConfig;
pub use utils::run_reverse_tunnel;
pub use utils::run_tunnel;
Loading

0 comments on commit 6f96808

Please sign in to comment.