Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
erebe committed Aug 5, 2024
1 parent a468428 commit 65f6884
Show file tree
Hide file tree
Showing 9 changed files with 171 additions and 162 deletions.
13 changes: 4 additions & 9 deletions src/tunnel/client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@ use crate::tunnel::client::WsClientConfig;
use crate::tunnel::connectors::TunnelConnector;
use crate::tunnel::listeners::TunnelListener;
use crate::tunnel::tls_reloader::TlsReloader;
use crate::tunnel::transport::{TunnelReader, TunnelWriter};
use crate::tunnel::{JwtTunnelConfig, RemoteAddr, TransportScheme, JWT_DECODE};
use crate::tunnel::transport::io::{TunnelReader, TunnelWriter};
use crate::tunnel::transport::jwt_token_to_tunnel;
use crate::tunnel::{RemoteAddr, TransportScheme};
use anyhow::Context;
use futures_util::pin_mut;
use hyper::header::COOKIE;
use jsonwebtoken::TokenData;
use log::debug;
use std::ops::Deref;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite};
Expand Down Expand Up @@ -179,11 +178,7 @@ impl WsClient {
.headers
.get(COOKIE)
.and_then(|h| h.to_str().ok())
.and_then(|h| {
let (validation, decode_key) = JWT_DECODE.deref();
let jwt: Option<TokenData<JwtTunnelConfig>> = jsonwebtoken::decode(h, decode_key, validation).ok();
jwt
})
.and_then(|h| jwt_token_to_tunnel(h).ok())
.map(|jwt| RemoteAddr {
protocol: jwt.claims.p,
host: Host::parse(&jwt.claims.r).unwrap_or_else(|_| Host::Domain(String::new())),
Expand Down
70 changes: 4 additions & 66 deletions src/tunnel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,68 +6,13 @@ mod tls_reloader;
mod transport;

use crate::TlsClientConfig;
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::fmt::{Debug, Display, Formatter};
use std::net::{IpAddr, SocketAddr};
use std::ops::Deref;
use std::path::PathBuf;
use std::str::FromStr;
use std::time::Duration;
use url::Host;
use uuid::Uuid;

#[derive(Debug, Clone, Serialize, Deserialize)]
struct JwtTunnelConfig {
pub id: String, // tunnel id
pub p: LocalProtocol, // protocol to use
pub r: String, // remote host
pub rp: u16, // remote port
}

impl JwtTunnelConfig {
fn new(request_id: Uuid, dest: &RemoteAddr) -> Self {
Self {
id: request_id.to_string(),
p: match dest.protocol {
LocalProtocol::Tcp { .. } => dest.protocol.clone(),
LocalProtocol::Udp { .. } => dest.protocol.clone(),
LocalProtocol::Stdio => LocalProtocol::Tcp { proxy_protocol: false },
LocalProtocol::Socks5 { .. } => LocalProtocol::Tcp { proxy_protocol: false },
LocalProtocol::HttpProxy { .. } => dest.protocol.clone(),
LocalProtocol::ReverseTcp => LocalProtocol::ReverseTcp,
LocalProtocol::ReverseUdp { .. } => dest.protocol.clone(),
LocalProtocol::ReverseSocks5 { .. } => dest.protocol.clone(),
LocalProtocol::TProxyTcp => LocalProtocol::Tcp { proxy_protocol: false },
LocalProtocol::TProxyUdp { timeout } => LocalProtocol::Udp { timeout },
LocalProtocol::Unix { .. } => LocalProtocol::Tcp { proxy_protocol: false },
LocalProtocol::ReverseUnix { .. } => dest.protocol.clone(),
LocalProtocol::ReverseHttpProxy { .. } => dest.protocol.clone(),
},
r: dest.host.to_string(),
rp: dest.port,
}
}
}

fn tunnel_to_jwt_token(request_id: Uuid, tunnel: &RemoteAddr) -> String {
let cfg = JwtTunnelConfig::new(request_id, tunnel);
let (alg, secret) = JWT_KEY.deref();
jsonwebtoken::encode(alg, &cfg, secret).unwrap_or_default()
}

static JWT_HEADER_PREFIX: &str = "authorization.bearer.";
static JWT_SECRET: &[u8; 15] = b"champignonfrais";
static JWT_KEY: Lazy<(Header, EncodingKey)> =
Lazy::new(|| (Header::new(Algorithm::HS256), EncodingKey::from_secret(JWT_SECRET)));

static JWT_DECODE: Lazy<(Validation, DecodingKey)> = Lazy::new(|| {
let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::with_capacity(0);
(validation, DecodingKey::from_secret(JWT_SECRET))
});

#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum LocalProtocol {
Expand Down Expand Up @@ -122,6 +67,10 @@ impl LocalProtocol {
| Self::ReverseHttpProxy { .. }
)
}

pub const fn is_dynamic_reverse_tunnel(&self) -> bool {
matches!(self, |Self::ReverseSocks5 { .. }| Self::ReverseHttpProxy { .. })
}
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -285,17 +234,6 @@ impl TransportAddr {
}
}

impl TryFrom<JwtTunnelConfig> for RemoteAddr {
type Error = anyhow::Error;
fn try_from(jwt: JwtTunnelConfig) -> anyhow::Result<Self> {
Ok(Self {
protocol: jwt.p,
host: Host::parse(&jwt.r)?,
port: jwt.rp,
})
}
}

pub fn to_host_port(addr: SocketAddr) -> (Host, u16) {
match addr.ip() {
IpAddr::V4(ip) => (Host::Ipv4(ip), addr.port()),
Expand Down
5 changes: 1 addition & 4 deletions src/tunnel/server/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,7 @@ impl WsServer {
};

let req_protocol = remote.protocol.clone();
let inject_cookie = matches!(
req_protocol,
LocalProtocol::ReverseSocks5 { .. } | LocalProtocol::ReverseHttpProxy { .. }
);
let inject_cookie = req_protocol.is_dynamic_reverse_tunnel();
let tunnel = match self.exec_tunnel(restriction, remote, client_addr).await {
Ok(ret) => ret,
Err(err) => {
Expand Down
7 changes: 3 additions & 4 deletions src/tunnel/server/utils.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::restrictions::types::{
AllowConfig, MatchConfig, RestrictionConfig, RestrictionsRules, ReverseTunnelConfigProtocol, TunnelConfigProtocol,
};
use crate::tunnel::{tunnel_to_jwt_token, JwtTunnelConfig, RemoteAddr, JWT_DECODE, JWT_HEADER_PREFIX};
use crate::tunnel::transport::{jwt_token_to_tunnel, tunnel_to_jwt_token, JwtTunnelConfig, JWT_HEADER_PREFIX};
use crate::tunnel::RemoteAddr;
use bytes::Bytes;
use http_body_util::combinators::BoxBody;
use http_body_util::Either;
Expand All @@ -11,7 +12,6 @@ use hyper::{http, Request, Response, StatusCode};
use jsonwebtoken::TokenData;
use std::cmp::min;
use std::net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::ops::Deref;
use tracing::{error, info, warn};
use url::Host;
use uuid::Uuid;
Expand Down Expand Up @@ -92,8 +92,7 @@ pub(super) fn extract_tunnel_info(req: &Request<Incoming>) -> Result<TokenData<J
.or_else(|| req.headers().get(COOKIE).and_then(|header| header.to_str().ok()))
.unwrap_or_default();

let (validation, decode_key) = JWT_DECODE.deref();
let jwt = match jsonwebtoken::decode(jwt, decode_key, validation) {
let jwt = match jwt_token_to_tunnel(jwt) {
Ok(jwt) => jwt,
err => {
warn!(
Expand Down
6 changes: 4 additions & 2 deletions src/tunnel/transport/http2.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use super::io::{TunnelRead, TunnelWrite, MAX_PACKET_LENGTH};
use crate::tunnel::client::WsClient;
use crate::tunnel::transport::{headers_from_file, TunnelRead, TunnelWrite, MAX_PACKET_LENGTH};
use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr, TransportScheme};
use crate::tunnel::transport::headers_from_file;
use crate::tunnel::transport::jwt::tunnel_to_jwt_token;
use crate::tunnel::{RemoteAddr, TransportScheme};
use anyhow::{anyhow, Context};
use bytes::{Bytes, BytesMut};
use http_body_util::{BodyExt, BodyStream, StreamBody};
Expand Down
71 changes: 69 additions & 2 deletions src/tunnel/transport/io.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::tunnel::transport::{TunnelRead, TunnelWrite};
use bytes::BufMut;
use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite};
use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite};
use bytes::{BufMut, BytesMut};
use futures_util::{pin_mut, FutureExt};
use std::future::Future;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::select;
Expand All @@ -9,6 +11,71 @@ use tokio::time::Instant;
use tracing::log::debug;
use tracing::{error, info, warn};

pub(super) static MAX_PACKET_LENGTH: usize = 64 * 1024;

pub trait TunnelWrite: Send + 'static {
fn buf_mut(&mut self) -> &mut BytesMut;
fn write(&mut self) -> impl Future<Output = Result<(), std::io::Error>> + Send;
fn ping(&mut self) -> impl Future<Output = Result<(), std::io::Error>> + Send;
fn close(&mut self) -> impl Future<Output = Result<(), std::io::Error>> + Send;
}

pub trait TunnelRead: Send + 'static {
fn copy(
&mut self,
writer: impl AsyncWrite + Unpin + Send,
) -> impl Future<Output = Result<(), std::io::Error>> + Send;
}

pub enum TunnelReader {
Websocket(WebsocketTunnelRead),
Http2(Http2TunnelRead),
}

impl TunnelRead for TunnelReader {
async fn copy(&mut self, writer: impl AsyncWrite + Unpin + Send) -> Result<(), std::io::Error> {
match self {
Self::Websocket(s) => s.copy(writer).await,
Self::Http2(s) => s.copy(writer).await,
}
}
}

pub enum TunnelWriter {
Websocket(WebsocketTunnelWrite),
Http2(Http2TunnelWrite),
}

impl TunnelWrite for TunnelWriter {
fn buf_mut(&mut self) -> &mut BytesMut {
match self {
Self::Websocket(s) => s.buf_mut(),
Self::Http2(s) => s.buf_mut(),
}
}

async fn write(&mut self) -> Result<(), std::io::Error> {
match self {
Self::Websocket(s) => s.write().await,
Self::Http2(s) => s.write().await,
}
}

async fn ping(&mut self) -> Result<(), std::io::Error> {
match self {
Self::Websocket(s) => s.ping().await,
Self::Http2(s) => s.ping().await,
}
}

async fn close(&mut self) -> Result<(), std::io::Error> {
match self {
Self::Websocket(s) => s.close().await,
Self::Http2(s) => s.close().await,
}
}
}

pub async fn propagate_local_to_remote(
local_rx: impl AsyncRead,
mut ws_tx: impl TunnelWrite,
Expand Down
75 changes: 75 additions & 0 deletions src/tunnel/transport/jwt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use crate::tunnel::{LocalProtocol, RemoteAddr};
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::ops::Deref;
use url::Host;
use uuid::Uuid;

pub static JWT_HEADER_PREFIX: &str = "authorization.bearer.";
static JWT_SECRET: &[u8; 15] = b"champignonfrais";
static JWT_KEY: Lazy<(Header, EncodingKey)> =
Lazy::new(|| (Header::new(Algorithm::HS256), EncodingKey::from_secret(JWT_SECRET)));

static JWT_DECODE: Lazy<(Validation, DecodingKey)> = Lazy::new(|| {
let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::with_capacity(0);
(validation, DecodingKey::from_secret(JWT_SECRET))
});

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtTunnelConfig {
pub id: String, // tunnel id
pub p: LocalProtocol, // protocol to use
pub r: String, // remote host
pub rp: u16, // remote port
}

impl JwtTunnelConfig {
fn new(request_id: Uuid, dest: &RemoteAddr) -> Self {
Self {
id: request_id.to_string(),
p: match dest.protocol {
LocalProtocol::Tcp { .. } => dest.protocol.clone(),
LocalProtocol::Udp { .. } => dest.protocol.clone(),
LocalProtocol::Stdio => LocalProtocol::Tcp { proxy_protocol: false },
LocalProtocol::Socks5 { .. } => LocalProtocol::Tcp { proxy_protocol: false },
LocalProtocol::HttpProxy { .. } => dest.protocol.clone(),
LocalProtocol::ReverseTcp => LocalProtocol::ReverseTcp,
LocalProtocol::ReverseUdp { .. } => dest.protocol.clone(),
LocalProtocol::ReverseSocks5 { .. } => dest.protocol.clone(),
LocalProtocol::TProxyTcp => LocalProtocol::Tcp { proxy_protocol: false },
LocalProtocol::TProxyUdp { timeout } => LocalProtocol::Udp { timeout },
LocalProtocol::Unix { .. } => LocalProtocol::Tcp { proxy_protocol: false },
LocalProtocol::ReverseUnix { .. } => dest.protocol.clone(),
LocalProtocol::ReverseHttpProxy { .. } => dest.protocol.clone(),
},
r: dest.host.to_string(),
rp: dest.port,
}
}
}

pub fn tunnel_to_jwt_token(request_id: Uuid, tunnel: &RemoteAddr) -> String {
let cfg = JwtTunnelConfig::new(request_id, tunnel);
let (alg, secret) = JWT_KEY.deref();
jsonwebtoken::encode(alg, &cfg, secret).unwrap_or_default()
}

pub fn jwt_token_to_tunnel(token: &str) -> anyhow::Result<TokenData<JwtTunnelConfig>> {
let (validation, decode_key) = JWT_DECODE.deref();
let jwt: TokenData<JwtTunnelConfig> = jsonwebtoken::decode(token, decode_key, validation)?;
Ok(jwt)
}

impl TryFrom<JwtTunnelConfig> for RemoteAddr {
type Error = anyhow::Error;
fn try_from(jwt: JwtTunnelConfig) -> anyhow::Result<Self> {
Ok(Self {
protocol: jwt.p,
host: Host::parse(&jwt.r)?,
port: jwt.rp,
})
}
}
Loading

0 comments on commit 65f6884

Please sign in to comment.