From a33a889b3dd634ea9f880b86f176819b83336b6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A3rebe=20-=20Romain=20GERARD?= Date: Mon, 29 Jul 2024 23:08:40 +0200 Subject: [PATCH] Refacto: Use proper type for WsClient --- src/main.rs | 185 +++++-------------------- src/protocols/tls/server.rs | 3 +- src/tunnel/client.rs | 176 ------------------------ src/tunnel/client/client.rs | 221 ++++++++++++++++++++++++++++++ src/tunnel/client/cnx_pool.rs | 74 ++++++++++ src/tunnel/client/config.rs | 76 ++++++++++ src/tunnel/client/mod.rs | 8 ++ src/tunnel/mod.rs | 54 +------- src/tunnel/tls_reloader.rs | 3 +- src/tunnel/transport/http2.rs | 58 ++++---- src/tunnel/transport/websocket.rs | 7 +- 11 files changed, 453 insertions(+), 412 deletions(-) delete mode 100644 src/tunnel/client.rs create mode 100644 src/tunnel/client/client.rs create mode 100644 src/tunnel/client/cnx_pool.rs create mode 100644 src/tunnel/client/config.rs create mode 100644 src/tunnel/client/mod.rs diff --git a/src/main.rs b/src/main.rs index 54f32d16..6d5b8c3a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,15 +3,23 @@ mod protocols; mod restrictions; mod tunnel; +use crate::protocols::dns::DnsResolver; +use crate::protocols::tls; +use crate::restrictions::types::RestrictionsRules; +use crate::tunnel::client::{TlsClientConfig, WsClient, WsClientConfig}; +use crate::tunnel::connectors::{Socks5TunnelConnector, TcpTunnelConnector, UdpTunnelConnector}; +use crate::tunnel::listeners::{ + new_stdio_listener, new_udp_listener, HttpProxyTunnelListener, Socks5TunnelListener, TcpTunnelListener, +}; +use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr, TransportScheme}; use base64::Engine; use clap::Parser; use hyper::header::HOST; use hyper::http::{HeaderName, HeaderValue}; use log::debug; -use once_cell::sync::Lazy; use parking_lot::{Mutex, RwLock}; use serde::{Deserialize, Serialize}; -use std::collections::{BTreeMap, HashMap}; +use std::collections::BTreeMap; use std::fmt::{Debug, Formatter}; use std::io::ErrorKind; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; @@ -21,21 +29,8 @@ use std::sync::Arc; use std::time::Duration; use std::{fmt, io}; use tokio::select; - -use tokio_rustls::rustls::pki_types::{CertificateDer, DnsName, PrivateKeyDer, ServerName}; -use tokio_rustls::TlsConnector; - +use tokio_rustls::rustls::pki_types::{CertificateDer, DnsName, PrivateKeyDer}; use tracing::{error, info}; - -use crate::protocols::dns::DnsResolver; -use crate::protocols::tls; -use crate::restrictions::types::RestrictionsRules; -use crate::tunnel::connectors::{Socks5TunnelConnector, TcpTunnelConnector, UdpTunnelConnector}; -use crate::tunnel::listeners::{ - new_stdio_listener, new_udp_listener, HttpProxyTunnelListener, Socks5TunnelListener, TcpTunnelListener, -}; -use crate::tunnel::tls_reloader::TlsReloader; -use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr, TransportScheme}; use tracing_subscriber::filter::Directive; use tracing_subscriber::EnvFilter; use url::{Host, Url}; @@ -695,22 +690,6 @@ fn parse_server_url(arg: &str) -> Result { Ok(url) } -#[derive(Clone)] -pub struct TlsClientConfig { - pub tls_sni_disabled: bool, - pub tls_sni_override: Option>, - pub tls_verify_certificate: bool, - tls_connector: Arc>, - pub tls_certificate_path: Option, - pub tls_key_path: Option, -} - -impl TlsClientConfig { - pub fn tls_connector(&self) -> TlsConnector { - self.tls_connector.read().clone() - } -} - #[derive(Debug)] pub struct TlsServerConfig { pub tls_certificate: Mutex>>, @@ -754,59 +733,6 @@ impl Debug for WsServerConfig { } } -#[derive(Clone)] -pub struct WsClientConfig { - pub remote_addr: TransportAddr, - pub socket_so_mark: Option, - pub http_upgrade_path_prefix: String, - pub http_upgrade_credentials: Option, - pub http_headers: HashMap, - pub http_headers_file: Option, - pub http_header_host: HeaderValue, - pub timeout_connect: Duration, - pub websocket_ping_frequency: Duration, - pub websocket_mask_frame: bool, - pub http_proxy: Option, - cnx_pool: Option>, - tls_reloader: Option>, - pub dns_resolver: DnsResolver, -} - -impl WsClientConfig { - pub const fn websocket_scheme(&self) -> &'static str { - match self.remote_addr.tls().is_some() { - false => "ws", - true => "wss", - } - } - - pub fn cnx_pool(&self) -> &bb8::Pool { - self.cnx_pool.as_ref().unwrap() - } - - pub fn websocket_host_url(&self) -> String { - format!("{}:{}", self.remote_addr.host(), self.remote_addr.port()) - } - - pub fn tls_server_name(&self) -> ServerName<'static> { - static INVALID_DNS_NAME: Lazy = Lazy::new(|| DnsName::try_from("dns-name-invalid.com").unwrap()); - - self.remote_addr - .tls() - .and_then(|tls| tls.tls_sni_override.as_ref()) - .map_or_else( - || match &self.remote_addr.host() { - Host::Domain(domain) => ServerName::DnsName( - DnsName::try_from(domain.clone()).unwrap_or_else(|_| INVALID_DNS_NAME.clone()), - ), - Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(*ip).into()), - Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(*ip).into()), - }, - |sni_override| ServerName::DnsName(sni_override.clone()), - ) - } -} - #[tokio::main] async fn main() -> anyhow::Result<()> { let args = Wstunnel::parse(); @@ -866,24 +792,7 @@ async fn main() -> anyhow::Result<()> { TransportScheme::from_str(args.remote_addr.scheme()).expect("invalid scheme in server url"); let tls = match transport_scheme { TransportScheme::Ws | TransportScheme::Http => None, - TransportScheme::Wss => Some(TlsClientConfig { - tls_connector: Arc::new(RwLock::new( - tls::tls_connector( - args.tls_verify_certificate, - transport_scheme.alpn_protocols(), - !args.tls_sni_disable, - tls_certificate, - tls_key, - ) - .expect("Cannot create tls connector"), - )), - tls_sni_override: args.tls_sni_override, - tls_verify_certificate: args.tls_verify_certificate, - tls_sni_disabled: args.tls_sni_disable, - tls_certificate_path: args.tls_certificate.clone(), - tls_key_path: args.tls_private_key.clone(), - }), - TransportScheme::Https => Some(TlsClientConfig { + TransportScheme::Wss | TransportScheme::Https => Some(TlsClientConfig { tls_connector: Arc::new(RwLock::new( tls::tls_connector( args.tls_verify_certificate, @@ -936,7 +845,7 @@ async fn main() -> anyhow::Result<()> { } else { None }; - let mut client_config = WsClientConfig { + let client_config = WsClientConfig { remote_addr: TransportAddr::new( TransportScheme::from_str(args.remote_addr.scheme()).unwrap(), args.remote_addr.host().unwrap().to_owned(), @@ -953,8 +862,6 @@ async fn main() -> anyhow::Result<()> { timeout_connect: Duration::from_secs(10), websocket_ping_frequency: args.websocket_ping_frequency_sec.unwrap_or(Duration::from_secs(30)), websocket_mask_frame: args.websocket_mask_frame, - cnx_pool: None, - tls_reloader: None, dns_resolver: DnsResolver::new_from_urls( &args.dns_resolver, http_proxy.clone(), @@ -965,28 +872,16 @@ async fn main() -> anyhow::Result<()> { http_proxy, }; - let tls_reloader = - TlsReloader::new_for_client(Arc::new(client_config.clone())).expect("Cannot create tls reloader"); - client_config.tls_reloader = Some(Arc::new(tls_reloader)); - let pool = bb8::Pool::builder() - .max_size(1000) - .min_idle(Some(args.connection_min_idle)) - .max_lifetime(Some(Duration::from_secs(30))) - .connection_timeout(args.connection_retry_max_backoff_sec) - .retry_connection(true) - .build(client_config.clone()) - .await - .unwrap(); - client_config.cnx_pool = Some(pool); - let client_config = Arc::new(client_config); + let client = + WsClient::new(client_config, args.connection_min_idle, args.connection_retry_max_backoff_sec).await?; // Start tunnels for tunnel in args.remote_to_local.into_iter() { - let client_config = client_config.clone(); + let client = client.clone(); match &tunnel.local_protocol { LocalProtocol::Tcp { proxy_protocol: _ } => { tokio::spawn(async move { - let cfg = client_config.clone(); + let cfg = client.config.clone(); let tcp_connector = TcpTunnelConnector::new( &tunnel.remote.0, tunnel.remote.1, @@ -1000,9 +895,7 @@ async fn main() -> anyhow::Result<()> { host, port, }; - if let Err(err) = - tunnel::client::run_reverse_tunnel(client_config, remote, tcp_connector).await - { + if let Err(err) = client.run_reverse_tunnel(remote, tcp_connector).await { error!("{:?}", err); } }); @@ -1011,7 +904,7 @@ async fn main() -> anyhow::Result<()> { let timeout = *timeout; tokio::spawn(async move { - let cfg = client_config.clone(); + let cfg = client.config.clone(); let (host, port) = to_host_port(tunnel.local); let remote = RemoteAddr { protocol: LocalProtocol::ReverseUdp { timeout }, @@ -1026,9 +919,7 @@ async fn main() -> anyhow::Result<()> { &cfg.dns_resolver, ); - if let Err(err) = - tunnel::client::run_reverse_tunnel(client_config, remote.clone(), udp_connector).await - { + if let Err(err) = client.run_reverse_tunnel(remote.clone(), udp_connector).await { error!("{:?}", err); } }); @@ -1037,7 +928,7 @@ async fn main() -> anyhow::Result<()> { let credentials = credentials.clone(); let timeout = *timeout; tokio::spawn(async move { - let cfg = client_config.clone(); + let cfg = client.config.clone(); let (host, port) = to_host_port(tunnel.local); let remote = RemoteAddr { protocol: LocalProtocol::ReverseSocks5 { timeout, credentials }, @@ -1047,9 +938,7 @@ async fn main() -> anyhow::Result<()> { let socks_connector = Socks5TunnelConnector::new(cfg.socket_so_mark, cfg.timeout_connect, &cfg.dns_resolver); - if let Err(err) = - tunnel::client::run_reverse_tunnel(client_config, remote, socks_connector).await - { + if let Err(err) = client.run_reverse_tunnel(remote, socks_connector).await { error!("{:?}", err); } }); @@ -1060,7 +949,7 @@ async fn main() -> anyhow::Result<()> { let credentials = credentials.clone(); let timeout = *timeout; tokio::spawn(async move { - let cfg = client_config.clone(); + let cfg = client.config.clone(); let (host, port) = to_host_port(tunnel.local); let remote = RemoteAddr { protocol: LocalProtocol::ReverseHttpProxy { timeout, credentials }, @@ -1075,9 +964,7 @@ async fn main() -> anyhow::Result<()> { &cfg.dns_resolver, ); - if let Err(err) = - tunnel::client::run_reverse_tunnel(client_config, remote.clone(), tcp_connector).await - { + if let Err(err) = client.run_reverse_tunnel(remote.clone(), tcp_connector).await { error!("{:?}", err); } }); @@ -1086,7 +973,7 @@ async fn main() -> anyhow::Result<()> { LocalProtocol::Unix { path } => { let path = path.clone(); tokio::spawn(async move { - let cfg = client_config.clone(); + let cfg = client.config.clone(); let tcp_connector = TcpTunnelConnector::new( &tunnel.remote.0, tunnel.remote.1, @@ -1101,9 +988,7 @@ async fn main() -> anyhow::Result<()> { host, port, }; - if let Err(err) = - tunnel::client::run_reverse_tunnel(client_config, remote, tcp_connector).await - { + if let Err(err) = client.run_reverse_tunnel(remote, tcp_connector).await { error!("{:?}", err); } }); @@ -1126,14 +1011,14 @@ async fn main() -> anyhow::Result<()> { } for tunnel in args.local_to_remote.into_iter() { - let client_config = client_config.clone(); + let client = client.clone(); match &tunnel.local_protocol { LocalProtocol::Tcp { proxy_protocol } => { let server = TcpTunnelListener::new(tunnel.local, tunnel.remote.clone(), *proxy_protocol).await?; tokio::spawn(async move { - if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { + if let Err(err) = client.run_tunnel(server).await { error!("{:?}", err); } }); @@ -1144,7 +1029,7 @@ async fn main() -> anyhow::Result<()> { let server = TproxyTcpTunnelListener::new(tunnel.local, false).await?; tokio::spawn(async move { - if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { + if let Err(err) = client.run_tunnel(server).await { error!("{:?}", err); } }); @@ -1154,7 +1039,7 @@ async fn main() -> anyhow::Result<()> { use crate::tunnel::listeners::UnixTunnelListener; let server = UnixTunnelListener::new(path, tunnel.remote.clone(), false).await?; // TODO: support proxy protocol tokio::spawn(async move { - if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { + if let Err(err) = client.run_tunnel(server).await { error!("{:?}", err); } }); @@ -1169,7 +1054,7 @@ async fn main() -> anyhow::Result<()> { use crate::tunnel::listeners::new_tproxy_udp; let server = new_tproxy_udp(tunnel.local, *timeout).await?; tokio::spawn(async move { - if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { + if let Err(err) = client.run_tunnel(server).await { error!("{:?}", err); } }); @@ -1182,7 +1067,7 @@ async fn main() -> anyhow::Result<()> { let server = new_udp_listener(tunnel.local, tunnel.remote.clone(), *timeout).await?; tokio::spawn(async move { - if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { + if let Err(err) = client.run_tunnel(server).await { error!("{:?}", err); } }); @@ -1190,7 +1075,7 @@ async fn main() -> anyhow::Result<()> { LocalProtocol::Socks5 { timeout, credentials } => { let server = Socks5TunnelListener::new(tunnel.local, *timeout, credentials.clone()).await?; tokio::spawn(async move { - if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { + if let Err(err) = client.run_tunnel(server).await { error!("{:?}", err); } }); @@ -1204,7 +1089,7 @@ async fn main() -> anyhow::Result<()> { HttpProxyTunnelListener::new(tunnel.local, *timeout, credentials.clone(), *proxy_protocol) .await?; tokio::spawn(async move { - if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { + if let Err(err) = client.run_tunnel(server).await { error!("{:?}", err); } }); @@ -1213,7 +1098,7 @@ async fn main() -> anyhow::Result<()> { LocalProtocol::Stdio => { let (server, mut handle) = new_stdio_listener(tunnel.remote.clone(), false).await?; // TODO: support proxy protocol tokio::spawn(async move { - if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { + if let Err(err) = client.run_tunnel(server).await { error!("{:?}", err); } }); diff --git a/src/protocols/tls/server.rs b/src/protocols/tls/server.rs index fdf7b348..5e3d24c6 100644 --- a/src/protocols/tls/server.rs +++ b/src/protocols/tls/server.rs @@ -1,4 +1,4 @@ -use crate::{TlsServerConfig, WsClientConfig}; +use crate::TlsServerConfig; use anyhow::{anyhow, Context}; use std::fs::File; @@ -9,6 +9,7 @@ use std::sync::Arc; use tokio::net::TcpStream; use tokio_rustls::client::TlsStream; +use crate::tunnel::client::WsClientConfig; use crate::tunnel::TransportAddr; use tokio_rustls::rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime}; diff --git a/src/tunnel/client.rs b/src/tunnel/client.rs deleted file mode 100644 index b2de8c40..00000000 --- a/src/tunnel/client.rs +++ /dev/null @@ -1,176 +0,0 @@ -use super::{JwtTunnelConfig, RemoteAddr, TransportScheme, JWT_DECODE}; -use crate::tunnel::connectors::TunnelConnector; -use crate::tunnel::listeners::TunnelListener; -use crate::tunnel::transport::{TunnelReader, TunnelWriter}; -use crate::{tunnel, WsClientConfig}; -use futures_util::pin_mut; -use hyper::header::COOKIE; -use jsonwebtoken::TokenData; -use log::debug; -use std::ops::Deref; -use std::sync::Arc; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::sync::oneshot; -use tokio_stream::StreamExt; -use tracing::{error, event, span, Instrument, Level, Span}; -use url::Host; -use uuid::Uuid; - -async fn connect_to_server( - request_id: Uuid, - client_cfg: &WsClientConfig, - remote_cfg: &RemoteAddr, - duplex_stream: (R, W), -) -> anyhow::Result<()> -where - R: AsyncRead + Send + 'static, - W: AsyncWrite + Send + 'static, -{ - // Connect to server with the correct protocol - let (ws_rx, ws_tx, response) = match client_cfg.remote_addr.scheme() { - TransportScheme::Ws | TransportScheme::Wss => { - tunnel::transport::websocket::connect(request_id, client_cfg, remote_cfg) - .await - .map(|(r, w, response)| (TunnelReader::Websocket(r), TunnelWriter::Websocket(w), response))? - } - TransportScheme::Http | TransportScheme::Https => { - tunnel::transport::http2::connect(request_id, client_cfg, remote_cfg) - .await - .map(|(r, w, response)| (TunnelReader::Http2(r), TunnelWriter::Http2(w), response))? - } - }; - - debug!("Server response: {:?}", response); - let (local_rx, local_tx) = duplex_stream; - let (close_tx, close_rx) = oneshot::channel::<()>(); - - // Forward local tx to websocket tx - let ping_frequency = client_cfg.websocket_ping_frequency; - tokio::spawn( - super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, Some(ping_frequency)) - .instrument(Span::current()), - ); - - // Forward websocket rx to local rx - let _ = super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).await; - - Ok(()) -} - -pub async fn run_tunnel(client_config: Arc, incoming_cnx: impl TunnelListener) -> anyhow::Result<()> { - pin_mut!(incoming_cnx); - while let Some(cnx) = incoming_cnx.next().await { - let (cnx_stream, remote_addr) = match cnx { - Ok((cnx_stream, remote_addr)) => (cnx_stream, remote_addr), - Err(err) => { - error!("Error accepting connection: {:?}", err); - continue; - } - }; - - let request_id = Uuid::now_v7(); - let span = span!( - Level::INFO, - "tunnel", - id = request_id.to_string(), - remote = format!("{}:{}", remote_addr.host, remote_addr.port) - ); - let client_config = client_config.clone(); - - let tunnel = async move { - let _ = connect_to_server(request_id, &client_config, &remote_addr, cnx_stream) - .await - .map_err(|err| error!("{:?}", err)); - } - .instrument(span); - - tokio::spawn(tunnel); - } - - Ok(()) -} - -pub async fn run_reverse_tunnel( - client_cfg: Arc, - remote_addr: RemoteAddr, - connector: impl TunnelConnector, -) -> anyhow::Result<()> { - loop { - let client_config = client_cfg.clone(); - let request_id = Uuid::now_v7(); - let span = span!( - Level::INFO, - "tunnel", - id = request_id.to_string(), - remote = format!("{}:{}", remote_addr.host, remote_addr.port) - ); - // Correctly configure tunnel cfg - let (ws_rx, ws_tx, response) = match client_cfg.remote_addr.scheme() { - TransportScheme::Ws | TransportScheme::Wss => { - match tunnel::transport::websocket::connect(request_id, &client_cfg, &remote_addr) - .instrument(span.clone()) - .await - { - Ok((r, w, response)) => (TunnelReader::Websocket(r), TunnelWriter::Websocket(w), response), - Err(err) => { - event!(parent: &span, Level::ERROR, "Retrying in 1sec, cannot connect to remote server: {:?}", err); - tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; - continue; - } - } - } - TransportScheme::Http | TransportScheme::Https => { - match tunnel::transport::http2::connect(request_id, &client_cfg, &remote_addr) - .instrument(span.clone()) - .await - { - Ok((r, w, response)) => (TunnelReader::Http2(r), TunnelWriter::Http2(w), response), - Err(err) => { - event!(parent: &span, Level::ERROR, "Retrying in 1sec, cannot connect to remote server: {:?}", err); - tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; - continue; - } - } - } - }; - - // Connect to endpoint - event!(parent: &span, Level::DEBUG, "Server response: {:?}", response); - let remote = response - .headers - .get(COOKIE) - .and_then(|h| h.to_str().ok()) - .and_then(|h| { - let (validation, decode_key) = JWT_DECODE.deref(); - let jwt: Option> = jsonwebtoken::decode(h, decode_key, validation).ok(); - jwt - }) - .map(|jwt| RemoteAddr { - protocol: jwt.claims.p, - host: Host::parse(&jwt.claims.r).unwrap_or_else(|_| Host::Domain(String::new())), - port: jwt.claims.rp, - }); - - let (local_rx, local_tx) = match connector.connect(&remote).instrument(span.clone()).await { - Ok(s) => s, - Err(err) => { - event!(parent: &span, Level::ERROR, "Cannot connect to {remote:?}: {err:?}"); - continue; - } - }; - - let (close_tx, close_rx) = oneshot::channel::<()>(); - let tunnel = async move { - let ping_frequency = client_config.websocket_ping_frequency; - tokio::spawn( - super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, Some(ping_frequency)) - .in_current_span(), - ); - - // Forward websocket rx to local rx - let _ = super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).await; - } - .instrument(span.clone()); - tokio::spawn(tunnel); - } -} diff --git a/src/tunnel/client/client.rs b/src/tunnel/client/client.rs new file mode 100644 index 00000000..7cd7bb3c --- /dev/null +++ b/src/tunnel/client/client.rs @@ -0,0 +1,221 @@ +use crate::tunnel; +use crate::tunnel::client::cnx_pool::WsConnection; +use crate::tunnel::client::WsClientConfig; +use crate::tunnel::connectors::TunnelConnector; +use crate::tunnel::listeners::TunnelListener; +use crate::tunnel::tls_reloader::TlsReloader; +use crate::tunnel::transport::{TunnelReader, TunnelWriter}; +use crate::tunnel::{JwtTunnelConfig, RemoteAddr, TransportScheme, JWT_DECODE}; +use anyhow::Context; +use futures_util::pin_mut; +use hyper::header::COOKIE; +use jsonwebtoken::TokenData; +use log::debug; +use std::ops::Deref; +use std::sync::Arc; +use std::time::Duration; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::sync::oneshot; +use tokio_stream::StreamExt; +use tracing::{error, event, span, Instrument, Level, Span}; +use url::Host; +use uuid::Uuid; + +#[derive(Clone)] +pub struct WsClient { + pub config: Arc, + pub cnx_pool: bb8::Pool, + _tls_reloader: Arc, +} + +impl WsClient { + pub async fn new( + config: WsClientConfig, + connection_min_idle: u32, + connection_retry_max_backoff_sec: Duration, + ) -> anyhow::Result { + let config = Arc::new(config); + let cnx = WsConnection::new(config.clone()); + let tls_reloader = TlsReloader::new_for_client(config.clone()).with_context(|| "Cannot create tls reloader")?; + let cnx_pool = bb8::Pool::builder() + .max_size(1000) + .min_idle(Some(connection_min_idle)) + .max_lifetime(Some(Duration::from_secs(30))) + .connection_timeout(connection_retry_max_backoff_sec) + .retry_connection(true) + .build(cnx) + .await?; + + Ok(Self { + config, + cnx_pool, + _tls_reloader: Arc::new(tls_reloader), + }) + } +} + +impl WsClient { + async fn connect_to_server( + &self, + request_id: Uuid, + remote_cfg: &RemoteAddr, + duplex_stream: (R, W), + ) -> anyhow::Result<()> + where + R: AsyncRead + Send + 'static, + W: AsyncWrite + Send + 'static, + { + // Connect to server with the correct protocol + let (ws_rx, ws_tx, response) = match self.config.remote_addr.scheme() { + TransportScheme::Ws | TransportScheme::Wss => { + tunnel::transport::websocket::connect(request_id, self, remote_cfg) + .await + .map(|(r, w, response)| (TunnelReader::Websocket(r), TunnelWriter::Websocket(w), response))? + } + TransportScheme::Http | TransportScheme::Https => { + tunnel::transport::http2::connect(request_id, self, remote_cfg) + .await + .map(|(r, w, response)| (TunnelReader::Http2(r), TunnelWriter::Http2(w), response))? + } + }; + + debug!("Server response: {:?}", response); + let (local_rx, local_tx) = duplex_stream; + let (close_tx, close_rx) = oneshot::channel::<()>(); + + // Forward local tx to websocket tx + let ping_frequency = self.config.websocket_ping_frequency; + tokio::spawn( + super::super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, Some(ping_frequency)) + .instrument(Span::current()), + ); + + // Forward websocket rx to local rx + let _ = super::super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).await; + + Ok(()) + } + + pub async fn run_tunnel(self, tunnel_listener: impl TunnelListener) -> anyhow::Result<()> { + pin_mut!(tunnel_listener); + while let Some(cnx) = tunnel_listener.next().await { + let (cnx_stream, remote_addr) = match cnx { + Ok((cnx_stream, remote_addr)) => (cnx_stream, remote_addr), + Err(err) => { + error!("Error accepting connection: {:?}", err); + continue; + } + }; + + let request_id = Uuid::now_v7(); + let span = span!( + Level::INFO, + "tunnel", + id = request_id.to_string(), + remote = format!("{}:{}", remote_addr.host, remote_addr.port) + ); + let client = self.clone(); + let tunnel = async move { + let _ = client + .connect_to_server(request_id, &remote_addr, cnx_stream) + .await + .map_err(|err| error!("{:?}", err)); + } + .instrument(span); + + tokio::spawn(tunnel); + } + + Ok(()) + } + + pub async fn run_reverse_tunnel( + self, + remote_addr: RemoteAddr, + connector: impl TunnelConnector, + ) -> anyhow::Result<()> { + loop { + let client = self.clone(); + let request_id = Uuid::now_v7(); + let span = span!( + Level::INFO, + "tunnel", + id = request_id.to_string(), + remote = format!("{}:{}", remote_addr.host, remote_addr.port) + ); + // Correctly configure tunnel cfg + let (ws_rx, ws_tx, response) = match client.config.remote_addr.scheme() { + TransportScheme::Ws | TransportScheme::Wss => { + match tunnel::transport::websocket::connect(request_id, &client, &remote_addr) + .instrument(span.clone()) + .await + { + Ok((r, w, response)) => (TunnelReader::Websocket(r), TunnelWriter::Websocket(w), response), + Err(err) => { + event!(parent: &span, Level::ERROR, "Retrying in 1sec, cannot connect to remote server: {:?}", err); + tokio::time::sleep(Duration::from_secs(1)).await; + continue; + } + } + } + TransportScheme::Http | TransportScheme::Https => { + match tunnel::transport::http2::connect(request_id, &client, &remote_addr) + .instrument(span.clone()) + .await + { + Ok((r, w, response)) => (TunnelReader::Http2(r), TunnelWriter::Http2(w), response), + Err(err) => { + event!(parent: &span, Level::ERROR, "Retrying in 1sec, cannot connect to remote server: {:?}", err); + tokio::time::sleep(Duration::from_secs(1)).await; + continue; + } + } + } + }; + + // Connect to endpoint + event!(parent: &span, Level::DEBUG, "Server response: {:?}", response); + let remote = response + .headers + .get(COOKIE) + .and_then(|h| h.to_str().ok()) + .and_then(|h| { + let (validation, decode_key) = JWT_DECODE.deref(); + let jwt: Option> = jsonwebtoken::decode(h, decode_key, validation).ok(); + jwt + }) + .map(|jwt| RemoteAddr { + protocol: jwt.claims.p, + host: Host::parse(&jwt.claims.r).unwrap_or_else(|_| Host::Domain(String::new())), + port: jwt.claims.rp, + }); + + let (local_rx, local_tx) = match connector.connect(&remote).instrument(span.clone()).await { + Ok(s) => s, + Err(err) => { + event!(parent: &span, Level::ERROR, "Cannot connect to {remote:?}: {err:?}"); + continue; + } + }; + + let (close_tx, close_rx) = oneshot::channel::<()>(); + let tunnel = async move { + let ping_frequency = client.config.websocket_ping_frequency; + tokio::spawn( + super::super::transport::io::propagate_local_to_remote( + local_rx, + ws_tx, + close_tx, + Some(ping_frequency), + ) + .in_current_span(), + ); + + // Forward websocket rx to local rx + let _ = super::super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).await; + } + .instrument(span.clone()); + tokio::spawn(tunnel); + } + } +} diff --git a/src/tunnel/client/cnx_pool.rs b/src/tunnel/client/cnx_pool.rs new file mode 100644 index 00000000..6a04985e --- /dev/null +++ b/src/tunnel/client/cnx_pool.rs @@ -0,0 +1,74 @@ +use crate::protocols; +use crate::protocols::tls; +use crate::tunnel::client::WsClientConfig; +use crate::tunnel::TransportStream; +use async_trait::async_trait; +use bb8::ManageConnection; +use std::ops::Deref; +use std::sync::Arc; +use tracing::instrument; + +#[derive(Clone)] +pub struct WsConnection(Arc); + +impl WsConnection { + pub fn new(config: Arc) -> Self { + Self(config) + } +} + +impl Deref for WsConnection { + type Target = WsClientConfig; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[async_trait] +impl ManageConnection for WsConnection { + type Connection = Option; + type Error = anyhow::Error; + + #[instrument(level = "trace", name = "cnx_server", skip_all)] + async fn connect(&self) -> Result { + let so_mark = self.socket_so_mark; + let timeout = self.timeout_connect; + + let tcp_stream = if let Some(http_proxy) = &self.http_proxy { + protocols::tcp::connect_with_http_proxy( + http_proxy, + self.remote_addr.host(), + self.remote_addr.port(), + so_mark, + timeout, + &self.dns_resolver, + ) + .await? + } else { + protocols::tcp::connect( + self.remote_addr.host(), + self.remote_addr.port(), + so_mark, + timeout, + &self.dns_resolver, + ) + .await? + }; + + if self.remote_addr.tls().is_some() { + let tls_stream = tls::connect(self, tcp_stream).await?; + Ok(Some(TransportStream::Tls(tls_stream))) + } else { + Ok(Some(TransportStream::Plain(tcp_stream))) + } + } + + async fn is_valid(&self, _conn: &mut Self::Connection) -> Result<(), Self::Error> { + Ok(()) + } + + fn has_broken(&self, conn: &mut Self::Connection) -> bool { + conn.is_none() + } +} diff --git a/src/tunnel/client/config.rs b/src/tunnel/client/config.rs new file mode 100644 index 00000000..ff3bf110 --- /dev/null +++ b/src/tunnel/client/config.rs @@ -0,0 +1,76 @@ +use crate::protocols::dns::DnsResolver; +use crate::tunnel::TransportAddr; +use hyper::header::{HeaderName, HeaderValue}; +use once_cell::sync::Lazy; +use parking_lot::RwLock; +use std::collections::HashMap; +use std::net::IpAddr; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; +use tokio_rustls::rustls::pki_types::{DnsName, ServerName}; +use tokio_rustls::TlsConnector; +use url::{Host, Url}; + +#[derive(Clone)] +pub struct WsClientConfig { + pub remote_addr: TransportAddr, + pub socket_so_mark: Option, + pub http_upgrade_path_prefix: String, + pub http_upgrade_credentials: Option, + pub http_headers: HashMap, + pub http_headers_file: Option, + pub http_header_host: HeaderValue, + pub timeout_connect: Duration, + pub websocket_ping_frequency: Duration, + pub websocket_mask_frame: bool, + pub http_proxy: Option, + pub dns_resolver: DnsResolver, +} + +impl WsClientConfig { + pub const fn websocket_scheme(&self) -> &'static str { + match self.remote_addr.tls().is_some() { + false => "ws", + true => "wss", + } + } + + pub fn websocket_host_url(&self) -> String { + format!("{}:{}", self.remote_addr.host(), self.remote_addr.port()) + } + + pub fn tls_server_name(&self) -> ServerName<'static> { + static INVALID_DNS_NAME: Lazy = Lazy::new(|| DnsName::try_from("dns-name-invalid.com").unwrap()); + + self.remote_addr + .tls() + .and_then(|tls| tls.tls_sni_override.as_ref()) + .map_or_else( + || match &self.remote_addr.host() { + Host::Domain(domain) => ServerName::DnsName( + DnsName::try_from(domain.clone()).unwrap_or_else(|_| INVALID_DNS_NAME.clone()), + ), + Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(*ip).into()), + Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(*ip).into()), + }, + |sni_override| ServerName::DnsName(sni_override.clone()), + ) + } +} + +#[derive(Clone)] +pub struct TlsClientConfig { + pub tls_sni_disabled: bool, + pub tls_sni_override: Option>, + pub tls_verify_certificate: bool, + pub tls_connector: Arc>, + pub tls_certificate_path: Option, + pub tls_key_path: Option, +} + +impl TlsClientConfig { + pub fn tls_connector(&self) -> TlsConnector { + self.tls_connector.read().clone() + } +} diff --git a/src/tunnel/client/mod.rs b/src/tunnel/client/mod.rs new file mode 100644 index 00000000..c001fec3 --- /dev/null +++ b/src/tunnel/client/mod.rs @@ -0,0 +1,8 @@ +#![allow(clippy::module_inception)] +mod client; +mod cnx_pool; +mod config; + +pub use client::WsClient; +pub use config::TlsClientConfig; +pub use config::WsClientConfig; diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index dfc3e4c3..7d96978d 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -5,10 +5,7 @@ pub mod server; pub mod tls_reloader; mod transport; -use crate::protocols::tls; -use crate::{protocols, LocalProtocol, TlsClientConfig, WsClientConfig}; -use async_trait::async_trait; -use bb8::ManageConnection; +use crate::{LocalProtocol, TlsClientConfig}; use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation}; use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; @@ -23,7 +20,6 @@ use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::net::TcpStream; use tokio_rustls::client::TlsStream; -use tracing::instrument; use url::Host; use uuid::Uuid; @@ -304,54 +300,6 @@ impl AsyncWrite for TransportStream { } } -#[async_trait] -impl ManageConnection for WsClientConfig { - type Connection = Option; - type Error = anyhow::Error; - - #[instrument(level = "trace", name = "cnx_server", skip_all)] - async fn connect(&self) -> Result { - let so_mark = self.socket_so_mark; - let timeout = self.timeout_connect; - - let tcp_stream = if let Some(http_proxy) = &self.http_proxy { - protocols::tcp::connect_with_http_proxy( - http_proxy, - self.remote_addr.host(), - self.remote_addr.port(), - so_mark, - timeout, - &self.dns_resolver, - ) - .await? - } else { - protocols::tcp::connect( - self.remote_addr.host(), - self.remote_addr.port(), - so_mark, - timeout, - &self.dns_resolver, - ) - .await? - }; - - if self.remote_addr.tls().is_some() { - let tls_stream = tls::connect(self, tcp_stream).await?; - Ok(Some(TransportStream::Tls(tls_stream))) - } else { - Ok(Some(TransportStream::Plain(tcp_stream))) - } - } - - async fn is_valid(&self, _conn: &mut Self::Connection) -> Result<(), Self::Error> { - Ok(()) - } - - fn has_broken(&self, conn: &mut Self::Connection) -> bool { - conn.is_none() - } -} - pub fn to_host_port(addr: SocketAddr) -> (Host, u16) { match addr.ip() { IpAddr::V4(ip) => (Host::Ipv4(ip), addr.port()), diff --git a/src/tunnel/tls_reloader.rs b/src/tunnel/tls_reloader.rs index 0573099d..8af275b5 100644 --- a/src/tunnel/tls_reloader.rs +++ b/src/tunnel/tls_reloader.rs @@ -1,6 +1,7 @@ use crate::protocols::tls; +use crate::tunnel::client::WsClientConfig; use crate::tunnel::tls_reloader::TlsReloaderState::{Client, Server}; -use crate::{WsClientConfig, WsServerConfig}; +use crate::WsServerConfig; use anyhow::Context; use log::trace; use notify::{EventKind, RecommendedWatcher, Watcher}; diff --git a/src/tunnel/transport/http2.rs b/src/tunnel/transport/http2.rs index e0967adf..951f2dfd 100644 --- a/src/tunnel/transport/http2.rs +++ b/src/tunnel/transport/http2.rs @@ -1,6 +1,6 @@ +use crate::tunnel::client::WsClient; use crate::tunnel::transport::{headers_from_file, TunnelRead, TunnelWrite, MAX_PACKET_LENGTH}; use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr, TransportScheme}; -use crate::WsClientConfig; use anyhow::{anyhow, Context}; use bytes::{Bytes, BytesMut}; use http_body_util::{BodyExt, BodyStream, StreamBody}; @@ -99,55 +99,57 @@ impl TunnelWrite for Http2TunnelWrite { pub async fn connect( request_id: Uuid, - client_cfg: &WsClientConfig, + client: &WsClient, dest_addr: &RemoteAddr, ) -> anyhow::Result<(Http2TunnelRead, Http2TunnelWrite, Parts)> { - let mut pooled_cnx = match client_cfg.cnx_pool().get().await { + let mut pooled_cnx = match client.cnx_pool.get().await { Ok(cnx) => Ok(cnx), Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")), }?; // In http2 HOST header does not exist, it is explicitly set in the authority from the request uri - let (headers_file, authority) = client_cfg - .http_headers_file - .as_ref() - .map_or((None, None), |headers_file_path| { - let (host, headers) = headers_from_file(headers_file_path); - let host = if let Some((_, v)) = host { - match (client_cfg.remote_addr.scheme(), client_cfg.remote_addr.port()) { - (TransportScheme::Http, 80) | (TransportScheme::Https, 443) => { - Some(v.to_str().unwrap_or("").to_string()) + let (headers_file, authority) = + client + .config + .http_headers_file + .as_ref() + .map_or((None, None), |headers_file_path| { + let (host, headers) = headers_from_file(headers_file_path); + let host = if let Some((_, v)) = host { + match (client.config.remote_addr.scheme(), client.config.remote_addr.port()) { + (TransportScheme::Http, 80) | (TransportScheme::Https, 443) => { + Some(v.to_str().unwrap_or("").to_string()) + } + (_, port) => Some(format!("{}:{}", v.to_str().unwrap_or(""), port)), } - (_, port) => Some(format!("{}:{}", v.to_str().unwrap_or(""), port)), - } - } else { - None - }; + } else { + None + }; - (Some(headers), host) - }); + (Some(headers), host) + }); let mut req = Request::builder() .method("POST") .uri(format!( "{}://{}/{}/events", - client_cfg.remote_addr.scheme(), + client.config.remote_addr.scheme(), authority .as_deref() - .unwrap_or_else(|| client_cfg.http_header_host.to_str().unwrap_or("")), - &client_cfg.http_upgrade_path_prefix + .unwrap_or_else(|| client.config.http_header_host.to_str().unwrap_or("")), + &client.config.http_upgrade_path_prefix )) .header(COOKIE, tunnel_to_jwt_token(request_id, dest_addr)) .header(CONTENT_TYPE, "application/json") .version(hyper::Version::HTTP_2); let headers = req.headers_mut().unwrap(); - for (k, v) in &client_cfg.http_headers { + for (k, v) in &client.config.http_headers { let _ = headers.remove(k); headers.append(k, v.clone()); } - if let Some(auth) = &client_cfg.http_upgrade_credentials { + if let Some(auth) = &client.config.http_upgrade_credentials { let _ = headers.remove(AUTHORIZATION); headers.append(AUTHORIZATION, auth.clone()); } @@ -164,7 +166,7 @@ pub async fn connect( let req = req.body(body).with_context(|| { format!( "failed to build HTTP request to contact the server {:?}", - client_cfg.remote_addr + client.config.remote_addr ) })?; debug!("with HTTP upgrade request {:?}", req); @@ -172,11 +174,11 @@ pub async fn connect( let (mut request_sender, cnx) = hyper::client::conn::http2::Builder::new(TokioExecutor::new()) .timer(TokioTimer::new()) .adaptive_window(true) - .keep_alive_interval(client_cfg.websocket_ping_frequency) + .keep_alive_interval(client.config.websocket_ping_frequency) .keep_alive_while_idle(false) .handshake(TokioIo::new(transport)) .await - .with_context(|| format!("failed to do http2 handshake with the server {:?}", client_cfg.remote_addr))?; + .with_context(|| format!("failed to do http2 handshake with the server {:?}", client.config.remote_addr))?; tokio::spawn(async move { if let Err(err) = cnx.await { error!("{:?}", err) @@ -186,7 +188,7 @@ pub async fn connect( let response = request_sender .send_request(req) .await - .with_context(|| format!("failed to send http2 request with the server {:?}", client_cfg.remote_addr))?; + .with_context(|| format!("failed to send http2 request with the server {:?}", client.config.remote_addr))?; if !response.status().is_success() { return Err(anyhow!( diff --git a/src/tunnel/transport/websocket.rs b/src/tunnel/transport/websocket.rs index f49a12e7..f1ebe9d1 100644 --- a/src/tunnel/transport/websocket.rs +++ b/src/tunnel/transport/websocket.rs @@ -1,6 +1,6 @@ +use crate::tunnel::client::WsClient; use crate::tunnel::transport::{headers_from_file, TunnelRead, TunnelWrite, MAX_PACKET_LENGTH}; use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr, JWT_HEADER_PREFIX}; -use crate::WsClientConfig; use anyhow::{anyhow, Context}; use bytes::{Bytes, BytesMut}; use fastwebsockets::{Frame, OpCode, Payload, WebSocketRead, WebSocketWrite}; @@ -135,10 +135,11 @@ impl TunnelRead for WebsocketTunnelRead { pub async fn connect( request_id: Uuid, - client_cfg: &WsClientConfig, + client: &WsClient, dest_addr: &RemoteAddr, ) -> anyhow::Result<(WebsocketTunnelRead, WebsocketTunnelWrite, Parts)> { - let mut pooled_cnx = match client_cfg.cnx_pool().get().await { + let client_cfg = &client.config; + let mut pooled_cnx = match client.cnx_pool.get().await { Ok(cnx) => Ok(cnx), Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")), }?;