Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unix domain socket binding for reverse tunnel #362

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/protocols/unix_sock/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,14 @@ impl Stream for UnixListenerStream {
pub async fn run_server(socket_path: &Path) -> Result<UnixListenerStream, anyhow::Error> {
info!("Starting Unix socket server listening cnx on {:?}", socket_path);

let path_to_delete = !socket_path.exists();
if socket_path.exists() {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the unix socket path, already exist we should not delete it. As it is not owned by wstunnel.

std::fs::remove_file(socket_path)
.with_context(|| format!("Failed to delete existing Unix socket at {:?}", socket_path))?;
}

let listener = UnixListener::bind(socket_path)
.with_context(|| format!("Cannot create Unix socket server {:?}", socket_path))?;
let path_to_delete = true;

Ok(UnixListenerStream::new(listener, path_to_delete))
}
11 changes: 10 additions & 1 deletion src/tunnel/listeners/unix_sock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::protocols::unix_sock;
use crate::protocols::unix_sock::UnixListenerStream;
use crate::tunnel::{LocalProtocol, RemoteAddr};
use anyhow::{anyhow, Context};
use std::path::Path;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::task::{ready, Poll};
use tokio::net::unix;
Expand All @@ -13,6 +13,7 @@ pub struct UnixTunnelListener {
listener: UnixListenerStream,
dest: (Host, u16),
proxy_protocol: bool,
path: PathBuf,
}

impl UnixTunnelListener {
Expand All @@ -25,6 +26,7 @@ impl UnixTunnelListener {
listener,
dest,
proxy_protocol,
path: path.to_path_buf(),
})
}
}
Expand Down Expand Up @@ -55,3 +57,10 @@ impl Stream for UnixTunnelListener {
Poll::Ready(ret)
}
}
impl Drop for UnixTunnelListener {
fn drop(&mut self) {
if let Err(err) = std::fs::remove_file(&self.path) {
log::error!("Cannot remove Unix domain socket file {}: {}", self.path.display(), err);
}
}
}
6 changes: 6 additions & 0 deletions src/tunnel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@
},
}

#[derive(Hash, Eq, PartialEq, Clone, Debug)]
pub enum BindAddr {
Socket(SocketAddr),
Unix(String), // Unix socket path

Check warning on line 63 in src/tunnel/mod.rs

View workflow job for this annotation

GitHub Actions / Build - Windows x86_64

variant `Unix` is never constructed

Check warning on line 63 in src/tunnel/mod.rs

View workflow job for this annotation

GitHub Actions / Build - Windows x86

variant `Unix` is never constructed
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use PathBuf instead of a string I think.

}

impl LocalProtocol {
pub const fn is_reverse_tunnel(&self) -> bool {
matches!(
Expand Down
8 changes: 4 additions & 4 deletions src/tunnel/server/reverse_tunnel.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use crate::tunnel::listeners::TunnelListener;
use crate::tunnel::BindAddr;
use crate::tunnel::RemoteAddr;
use ahash::AHashMap;
use anyhow::anyhow;
use futures_util::{pin_mut, StreamExt};
use log::warn;
use parking_lot::Mutex;
use std::future::Future;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
Expand All @@ -29,7 +29,7 @@ impl<T: TunnelListener> Clone for ReverseTunnelItem<T> {
}

pub struct ReverseTunnelServer<T: TunnelListener> {
servers: Arc<Mutex<AHashMap<SocketAddr, ReverseTunnelItem<T>>>>,
servers: Arc<Mutex<AHashMap<BindAddr, ReverseTunnelItem<T>>>>,
}

impl<T: TunnelListener> ReverseTunnelServer<T> {
Expand All @@ -41,7 +41,7 @@ impl<T: TunnelListener> ReverseTunnelServer<T> {

pub async fn run_listening_server(
&self,
bind_addr: SocketAddr,
bind_addr: BindAddr,
gen_listening_server: impl Future<Output = anyhow::Result<T>>,
) -> anyhow::Result<((<T as TunnelListener>::Reader, <T as TunnelListener>::Writer), RemoteAddr)>
where
Expand All @@ -57,7 +57,7 @@ impl<T: TunnelListener> ReverseTunnelServer<T> {
let nb_seen_clients = Arc::new(AtomicUsize::new(0));
let seen_clients = nb_seen_clients.clone();
let server = self.servers.clone();
let local_srv2 = bind_addr;
let local_srv2 = bind_addr.clone();

let fut = async move {
scopeguard::defer!({
Expand Down
63 changes: 32 additions & 31 deletions src/tunnel/server/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use std::sync::{Arc, LazyLock};
use std::time::Duration;

use crate::protocols;
use crate::tunnel::{try_to_sock_addr, LocalProtocol, RemoteAddr};
use crate::tunnel::{try_to_sock_addr, BindAddr, LocalProtocol, RemoteAddr};
use hyper::body::Incoming;
use hyper::server::conn::{http1, http2};
use hyper::service::service_fn;
Expand Down Expand Up @@ -194,9 +194,10 @@ impl WsServer {
let header = ppp::v2::Builder::with_addresses(
ppp::v2::Version::Two | ppp::v2::Command::Proxy,
ppp::v2::Protocol::Stream,
(client_address, tx.local_addr()?),
(client_address, tx.local_addr().unwrap()),
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bad merge

)
.build()?;
.build()
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bad merge

.unwrap();
let _ = tx.write_all(&header).await;
}

Expand All @@ -210,8 +211,9 @@ impl WsServer {
let local_srv = (remote.host, remote_port);
let bind = try_to_sock_addr(local_srv.clone())?;
let listening_server = async { TcpTunnelListener::new(bind, local_srv.clone(), false).await };
let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?;

let ((local_rx, local_tx), remote) = SERVERS
.run_listening_server(BindAddr::Socket(bind), listening_server)
.await?;
Ok((remote, Box::pin(local_rx), Box::pin(local_tx)))
}
LocalProtocol::ReverseUdp { timeout } => {
Expand All @@ -222,7 +224,9 @@ impl WsServer {
let local_srv = (remote.host, remote_port);
let bind = try_to_sock_addr(local_srv.clone())?;
let listening_server = async { UdpTunnelListener::new(bind, local_srv.clone(), timeout).await };
let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?;
let ((local_rx, local_tx), remote) = SERVERS
.run_listening_server(BindAddr::Socket(bind), listening_server)
.await?;
Ok((remote, Box::pin(local_rx), Box::pin(local_tx)))
}
LocalProtocol::ReverseSocks5 { timeout, credentials } => {
Expand All @@ -233,7 +237,9 @@ impl WsServer {
let local_srv = (remote.host, remote_port);
let bind = try_to_sock_addr(local_srv.clone())?;
let listening_server = async { Socks5TunnelListener::new(bind, timeout, credentials).await };
let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?;
let ((local_rx, local_tx), remote) = SERVERS
.run_listening_server(BindAddr::Socket(bind), listening_server)
.await?;

Ok((remote, Box::pin(local_rx), Box::pin(local_tx)))
}
Expand All @@ -245,7 +251,9 @@ impl WsServer {
let local_srv = (remote.host, remote_port);
let bind = try_to_sock_addr(local_srv.clone())?;
let listening_server = async { HttpProxyTunnelListener::new(bind, timeout, credentials, false).await };
let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?;
let ((local_rx, local_tx), remote) = SERVERS
.run_listening_server(BindAddr::Socket(bind), listening_server)
.await?;

Ok((remote, Box::pin(local_rx), Box::pin(local_tx)))
}
Expand All @@ -255,11 +263,10 @@ impl WsServer {
static SERVERS: LazyLock<ReverseTunnelServer<UnixTunnelListener>> =
LazyLock::new(ReverseTunnelServer::new);

let remote_port = find_mapped_port(remote.port, restriction);
let local_srv = (remote.host, remote_port);
let bind = try_to_sock_addr(local_srv.clone())?;
let listening_server = async { UnixTunnelListener::new(path, local_srv, false).await };
let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?;
let listening_server = async { UnixTunnelListener::new(path, (remote.host, remote.port), false).await };
let ((local_rx, local_tx), remote) = SERVERS
.run_listening_server(BindAddr::Unix(path.to_str().unwrap().to_string()), listening_server)
.await?;

Ok((remote, Box::pin(local_rx), Box::pin(local_tx)))
}
Expand Down Expand Up @@ -291,7 +298,6 @@ impl WsServer {
move |req: Request<Incoming>| {
ws_server_upgrade(server.clone(), restrictions.clone(), restrict_path.clone(), client_addr, req)
.map::<anyhow::Result<_>, _>(Ok)
.instrument(mk_span())
}
};

Expand All @@ -302,7 +308,6 @@ impl WsServer {
move |req: Request<Incoming>| {
http_server_upgrade(server.clone(), restrictions.clone(), restrict_path.clone(), client_addr, req)
.map::<anyhow::Result<_>, _>(Ok)
.instrument(mk_span())
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bad merge

}
};

Expand Down Expand Up @@ -337,7 +342,6 @@ impl WsServer {
.unwrap())
}
}
.instrument(mk_span())
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bad merge

}
};

Expand Down Expand Up @@ -383,12 +387,20 @@ impl WsServer {
}
};

let span = span!(Level::INFO, "cnx", peer = peer_addr.to_string(),);
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bad merge

info!(parent: &span, "Accepting connection");
if let Err(err) = protocols::tcp::configure_socket(SockRef::from(&stream), &None) {
warn!("Error while configuring server socket {:?}", err);
}

let span = span!(
Level::INFO,
"tunnel",
id = tracing::field::Empty,
remote = tracing::field::Empty,
peer = peer_addr.to_string(),
forwarded_for = tracing::field::Empty
);

info!("Accepting connection");
let server = self.clone();
let restrictions = restrictions.restrictions_rules().clone();

Expand Down Expand Up @@ -435,9 +447,7 @@ impl WsServer {
mk_websocket_upgrade_fn(server, restrictions.clone(), restrict_path, peer_addr);
let conn_fut = http1::Builder::new()
.timer(TokioTimer::new())
// https://github.com/erebe/wstunnel/issues/358
// disabled, to avoid conflict with --connection-min-idle flag, that open idle connections
.header_read_timeout(None)
.header_read_timeout(Duration::from_secs(10))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bad merge

.serve_connection(tls_stream, service_fn(websocket_upgrade_fn))
.with_upgrades();

Expand All @@ -450,6 +460,7 @@ impl WsServer {
.instrument(span);

tokio::spawn(fut);
// Normal
}
// HTTP without TLS
None => {
Expand Down Expand Up @@ -477,16 +488,6 @@ impl WsServer {
}
}

fn mk_span() -> Span {
span!(
Level::INFO,
"tunnel",
id = tracing::field::Empty,
remote = tracing::field::Empty,
forwarded_for = tracing::field::Empty
)
}

impl Debug for WsServerConfig {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("WsServerConfig")
Expand Down
Loading