Skip to content

Commit

Permalink
cleanup argument parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
erebe committed Aug 1, 2024
1 parent f149b81 commit 811a1e6
Showing 1 changed file with 131 additions and 100 deletions.
231 changes: 131 additions & 100 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ struct Client {
/// 'socks5://[::1]:1212' => listen on server for incoming socks5 request on port 1212 and forward dynamically request from local machine (login/password is supported)
/// 'http://[::1]:1212' => listen on server for incoming http proxy request on port 1212 and forward dynamically request from local machine (login/password is supported)
/// 'unix://wstunnel.sock:g.com:443' => listen on server for incoming data from unix socket of path wstunnel.sock and forward to g.com:443 from local machine
#[arg(short='R', long, value_name = "{tcp,udp,socks5,unix}://[BIND:]PORT:HOST:PORT", value_parser = parse_tunnel_arg, verbatim_doc_comment)]
#[arg(short='R', long, value_name = "{tcp,udp,socks5,unix}://[BIND:]PORT:HOST:PORT", value_parser = parse_reverse_tunnel_arg, verbatim_doc_comment)]
remote_to_local: Vec<LocalToRemote>,

/// (linux only) Mark network packet with SO_MARK sockoption with the specified value.
Expand Down Expand Up @@ -468,35 +468,53 @@ fn parse_tunnel_dest(remaining: &str) -> Result<(Host<String>, u16, BTreeMap<Str

fn parse_tunnel_arg(arg: &str) -> Result<LocalToRemote, io::Error> {
use std::io::Error;
let get_timeout = |options: &BTreeMap<String, String>| {
options
.get("timeout_sec")
.and_then(|x| x.parse::<u64>().ok())
.map(|d| if d == 0 { None } else { Some(Duration::from_secs(d)) })
.unwrap_or(Some(Duration::from_secs(30)))
};
let get_credentials = |options: &BTreeMap<String, String>| {
options
.get("login")
.and_then(|login| options.get("password").map(|p| (login.to_string(), p.to_string())))
};
let get_proxy_protocol = |options: &BTreeMap<String, String>| options.contains_key("proxy_protocol");

let Some((proto, tunnel_info)) = arg.split_once("://") else {
return Err(Error::new(
ErrorKind::InvalidInput,
format!("cannot parse protocol from {}", arg),
));
};

match &arg[..6] {
"tcp://" => {
let (local_bind, remaining) = parse_local_bind(&arg[6..])?;
match proto {
"tcp" => {
let (local_bind, remaining) = parse_local_bind(tunnel_info)?;
let (dest_host, dest_port, options) = parse_tunnel_dest(remaining)?;
let proxy_protocol = options.contains_key("proxy_protocol");
Ok(LocalToRemote {
local_protocol: LocalProtocol::Tcp { proxy_protocol },
local_protocol: LocalProtocol::Tcp {
proxy_protocol: get_proxy_protocol(&options),
},
local: local_bind,
remote: (dest_host, dest_port),
})
}
"udp://" => {
let (local_bind, remaining) = parse_local_bind(&arg[6..])?;
"udp" => {
let (local_bind, remaining) = parse_local_bind(tunnel_info)?;
let (dest_host, dest_port, options) = parse_tunnel_dest(remaining)?;
let timeout = options
.get("timeout_sec")
.and_then(|x| x.parse::<u64>().ok())
.map(|d| if d == 0 { None } else { Some(Duration::from_secs(d)) })
.unwrap_or(Some(Duration::from_secs(30)));

Ok(LocalToRemote {
local_protocol: LocalProtocol::Udp { timeout },
local_protocol: LocalProtocol::Udp {
timeout: get_timeout(&options),
},
local: local_bind,
remote: (dest_host, dest_port),
})
}
"unix:/" => {
let Some((path, remote)) = arg[7..].split_once(':') else {
"unix" => {
let Some((path, remote)) = tunnel_info.split_once(':') else {
return Err(Error::new(
ErrorKind::InvalidInput,
format!("cannot parse unix socket path from {}", arg),
Expand All @@ -511,89 +529,104 @@ fn parse_tunnel_arg(arg: &str) -> Result<LocalToRemote, io::Error> {
remote: (dest_host, dest_port),
})
}
"http:/" => {
let (local_bind, remaining) = parse_local_bind(&arg["http://".len()..])?;
"http" => {
let (local_bind, remaining) = parse_local_bind(tunnel_info)?;
let x = format!("0.0.0.0:0?{}", remaining);
let (dest_host, dest_port, options) = parse_tunnel_dest(&x)?;
let proxy_protocol = options.contains_key("proxy_protocol");
let timeout = options
.get("timeout_sec")
.and_then(|x| x.parse::<u64>().ok())
.map(|d| if d == 0 { None } else { Some(Duration::from_secs(d)) })
.unwrap_or(Some(Duration::from_secs(30)));
let credentials = options
.get("login")
.and_then(|login| options.get("password").map(|p| (login.to_string(), p.to_string())));
Ok(LocalToRemote {
local_protocol: LocalProtocol::HttpProxy {
timeout,
credentials,
proxy_protocol,
timeout: get_timeout(&options),
credentials: get_credentials(&options),
proxy_protocol: get_proxy_protocol(&options),
},
local: local_bind,
remote: (dest_host, dest_port),
})
}
_ => match &arg[..8] {
"socks5:/" => {
let (local_bind, remaining) = parse_local_bind(&arg["socks5://".len()..])?;
let x = format!("0.0.0.0:0?{}", remaining);
let (dest_host, dest_port, options) = parse_tunnel_dest(&x)?;
let timeout = options
.get("timeout_sec")
.and_then(|x| x.parse::<u64>().ok())
.map(|d| if d == 0 { None } else { Some(Duration::from_secs(d)) })
.unwrap_or(Some(Duration::from_secs(30)));
let credentials = options
.get("login")
.and_then(|login| options.get("password").map(|p| (login.to_string(), p.to_string())));
Ok(LocalToRemote {
local_protocol: LocalProtocol::Socks5 { timeout, credentials },
local: local_bind,
remote: (dest_host, dest_port),
})
}
"stdio://" => {
let (dest_host, dest_port, _options) = parse_tunnel_dest(&arg["stdio://".len()..])?;
Ok(LocalToRemote {
local_protocol: LocalProtocol::Stdio,
local: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(0), 0)),
remote: (dest_host, dest_port),
})
}
"tproxy+t" => {
let (local_bind, remaining) = parse_local_bind(&arg["tproxy+tcp://".len()..])?;
let x = format!("0.0.0.0:0?{}", remaining);
let (dest_host, dest_port, _options) = parse_tunnel_dest(&x)?;
Ok(LocalToRemote {
local_protocol: LocalProtocol::TProxyTcp,
local: local_bind,
remote: (dest_host, dest_port),
})
}
"tproxy+u" => {
let (local_bind, remaining) = parse_local_bind(&arg["tproxy+udp://".len()..])?;
let x = format!("0.0.0.0:0?{}", remaining);
let (dest_host, dest_port, options) = parse_tunnel_dest(&x)?;
let timeout = options
.get("timeout_sec")
.and_then(|x| x.parse::<u64>().ok())
.map(|d| if d == 0 { None } else { Some(Duration::from_secs(d)) })
.unwrap_or(Some(Duration::from_secs(30)));
Ok(LocalToRemote {
local_protocol: LocalProtocol::TProxyUdp { timeout },
local: local_bind,
remote: (dest_host, dest_port),
})
}
_ => Err(Error::new(
ErrorKind::InvalidInput,
format!("Invalid local protocol for tunnel {}", arg),
)),
},
"socks5" => {
let (local_bind, remaining) = parse_local_bind(tunnel_info)?;
let x = format!("0.0.0.0:0?{}", remaining);
let (dest_host, dest_port, options) = parse_tunnel_dest(&x)?;
Ok(LocalToRemote {
local_protocol: LocalProtocol::Socks5 {
timeout: get_timeout(&options),
credentials: get_credentials(&options),
},
local: local_bind,
remote: (dest_host, dest_port),
})
}
"stdio" => {
let (dest_host, dest_port, _options) = parse_tunnel_dest(tunnel_info)?;
Ok(LocalToRemote {
local_protocol: LocalProtocol::Stdio,
local: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(0), 0)),
remote: (dest_host, dest_port),
})
}
"tproxy+tcp" => {
let (local_bind, remaining) = parse_local_bind(tunnel_info)?;
let x = format!("0.0.0.0:0?{}", remaining);
let (dest_host, dest_port, _options) = parse_tunnel_dest(&x)?;
Ok(LocalToRemote {
local_protocol: LocalProtocol::TProxyTcp,
local: local_bind,
remote: (dest_host, dest_port),
})
}
"tproxy+udp" => {
let (local_bind, remaining) = parse_local_bind(tunnel_info)?;
let x = format!("0.0.0.0:0?{}", remaining);
let (dest_host, dest_port, options) = parse_tunnel_dest(&x)?;
Ok(LocalToRemote {
local_protocol: LocalProtocol::TProxyUdp {
timeout: get_timeout(&options),
},
local: local_bind,
remote: (dest_host, dest_port),
})
}
_ => Err(Error::new(
ErrorKind::InvalidInput,
format!("Invalid local protocol for tunnel {}", arg),
)),
}
}

fn parse_reverse_tunnel_arg(arg: &str) -> Result<LocalToRemote, io::Error> {
let proto = parse_tunnel_arg(arg)?;
let local_protocol = match proto.local_protocol {
LocalProtocol::Tcp { .. } => LocalProtocol::ReverseTcp {},
LocalProtocol::Udp { timeout } => LocalProtocol::ReverseUdp { timeout },
LocalProtocol::Socks5 { timeout, credentials } => LocalProtocol::ReverseSocks5 { timeout, credentials },
LocalProtocol::HttpProxy {
timeout,
credentials,
proxy_protocol: _proxy_protocol,
} => LocalProtocol::ReverseHttpProxy { timeout, credentials },
LocalProtocol::Unix { path } => LocalProtocol::ReverseUnix { path },
LocalProtocol::ReverseTcp { .. }
| LocalProtocol::ReverseUdp { .. }
| LocalProtocol::ReverseSocks5 { .. }
| LocalProtocol::ReverseHttpProxy { .. }
| LocalProtocol::ReverseUnix { .. }
| LocalProtocol::TProxyTcp
| LocalProtocol::TProxyUdp { .. }
| LocalProtocol::Stdio => {
return Err(io::Error::new(
ErrorKind::InvalidInput,
format!("Cannot use {:?} as reverse tunnels {}", proto.local_protocol, arg),
))
}
};

Ok(LocalToRemote {
local_protocol,
local: proto.local,
remote: proto.remote,
})
}

fn parse_sni_override(arg: &str) -> Result<DnsName<'static>, io::Error> {
match DnsName::try_from(arg.to_string()) {
Ok(val) => Ok(val),
Expand Down Expand Up @@ -788,7 +821,7 @@ async fn main() -> anyhow::Result<()> {
for tunnel in args.remote_to_local.into_iter() {
let client = client.clone();
match &tunnel.local_protocol {
LocalProtocol::Tcp { proxy_protocol: _ } => {
LocalProtocol::ReverseTcp { .. } => {
tokio::spawn(async move {
let cfg = client.config.clone();
let tcp_connector = TcpTunnelConnector::new(
Expand All @@ -809,7 +842,7 @@ async fn main() -> anyhow::Result<()> {
}
});
}
LocalProtocol::Udp { timeout } => {
LocalProtocol::ReverseUdp { timeout } => {
let timeout = *timeout;

tokio::spawn(async move {
Expand All @@ -833,7 +866,7 @@ async fn main() -> anyhow::Result<()> {
}
});
}
LocalProtocol::Socks5 { timeout, credentials } => {
LocalProtocol::ReverseSocks5 { timeout, credentials } => {
let credentials = credentials.clone();
let timeout = *timeout;
tokio::spawn(async move {
Expand All @@ -852,9 +885,7 @@ async fn main() -> anyhow::Result<()> {
}
});
}
LocalProtocol::HttpProxy {
timeout, credentials, ..
} => {
LocalProtocol::ReverseHttpProxy { timeout, credentials } => {
let credentials = credentials.clone();
let timeout = *timeout;
tokio::spawn(async move {
Expand All @@ -879,7 +910,7 @@ async fn main() -> anyhow::Result<()> {
});
}
#[cfg(unix)]
LocalProtocol::Unix { path } => {
LocalProtocol::ReverseUnix { path } => {
let path = path.clone();
tokio::spawn(async move {
let cfg = client.config.clone();
Expand All @@ -903,17 +934,17 @@ async fn main() -> anyhow::Result<()> {
});
}
#[cfg(not(unix))]
LocalProtocol::Unix { .. } => {
LocalProtocol::ReverseUnix { .. } => {
panic!("Unix socket is not available for non Unix platform")
}
LocalProtocol::Stdio
| LocalProtocol::TProxyTcp
| LocalProtocol::TProxyUdp { .. }
| LocalProtocol::ReverseTcp
| LocalProtocol::ReverseUdp { .. }
| LocalProtocol::ReverseSocks5 { .. }
| LocalProtocol::ReverseHttpProxy { .. } => {}
LocalProtocol::ReverseUnix { .. } => {
| LocalProtocol::Tcp { .. }
| LocalProtocol::Udp { .. }
| LocalProtocol::Socks5 { .. }
| LocalProtocol::HttpProxy { .. } => {}
LocalProtocol::Unix { .. } => {
panic!("Invalid protocol for reverse tunnel");
}
}
Expand Down

0 comments on commit 811a1e6

Please sign in to comment.