Skip to content

Commit

Permalink
fix(#360): data not flushed immediatly on reverse tunnel
Browse files Browse the repository at this point in the history
  • Loading branch information
erebe committed Sep 27, 2024
1 parent b1e0982 commit 0366502
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 11 deletions.
10 changes: 8 additions & 2 deletions src/tunnel/client/l4_transport_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@ use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;

pub enum TransportStream {
Plain(TcpStream),
Tls(TlsStream<TcpStream>),
Tls(tokio_rustls::client::TlsStream<TcpStream>),
TlsSrv(tokio_rustls::server::TlsStream<TcpStream>),
}

impl AsyncRead for TransportStream {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
match self.get_mut() {
Self::Plain(cnx) => Pin::new(cnx).poll_read(cx, buf),
Self::Tls(cnx) => Pin::new(cnx).poll_read(cx, buf),
Self::TlsSrv(cnx) => Pin::new(cnx).poll_read(cx, buf),
}
}
}
Expand All @@ -24,20 +25,23 @@ impl AsyncWrite for TransportStream {
match self.get_mut() {
Self::Plain(cnx) => Pin::new(cnx).poll_write(cx, buf),
Self::Tls(cnx) => Pin::new(cnx).poll_write(cx, buf),
Self::TlsSrv(cnx) => Pin::new(cnx).poll_write(cx, buf),
}
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
match self.get_mut() {
Self::Plain(cnx) => Pin::new(cnx).poll_flush(cx),
Self::Tls(cnx) => Pin::new(cnx).poll_flush(cx),
Self::TlsSrv(cnx) => Pin::new(cnx).poll_flush(cx),
}
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
match self.get_mut() {
Self::Plain(cnx) => Pin::new(cnx).poll_shutdown(cx),
Self::Tls(cnx) => Pin::new(cnx).poll_shutdown(cx),
Self::TlsSrv(cnx) => Pin::new(cnx).poll_shutdown(cx),
}
}

Expand All @@ -49,13 +53,15 @@ impl AsyncWrite for TransportStream {
match self.get_mut() {
Self::Plain(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs),
Self::Tls(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs),
Self::TlsSrv(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs),
}
}

fn is_write_vectored(&self) -> bool {
match &self {
Self::Plain(cnx) => cnx.is_write_vectored(),
Self::Tls(cnx) => cnx.is_write_vectored(),
Self::TlsSrv(cnx) => cnx.is_write_vectored(),
}
}
}
10 changes: 9 additions & 1 deletion src/tunnel/server/handler_websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,13 @@ use hyper::header::{HeaderValue, SEC_WEBSOCKET_PROTOCOL};
use hyper::{Request, Response};
use std::net::SocketAddr;
use std::sync::Arc;
use fastwebsockets::Role;
use hyper_util::rt::TokioIo;
use tokio::net::TcpStream;
use tokio::sync::oneshot;
use tokio_rustls::server::TlsStream;
use tracing::{error, warn, Instrument, Span};
use crate::tunnel::client::l4_transport_stream::TransportStream;

pub(super) async fn ws_server_upgrade(
server: WsServer,
Expand Down Expand Up @@ -46,7 +51,10 @@ pub(super) async fn ws_server_upgrade(
tokio::spawn(
async move {
let (ws_rx, ws_tx) = match fut.await {
Ok(mut ws) => {
Ok(ws) => {
let tcp_inner = ws.into_inner().into_inner().downcast::<TokioIo<TlsStream<TcpStream>>>().unwrap();
let tcp_inner = TransportStream::TlsSrv(tcp_inner.io.into_inner());
let mut ws = fastwebsockets::WebSocket::after_handshake(tcp_inner, Role::Server);
ws.set_auto_pong(false);
ws.set_auto_close(false);
ws.set_auto_apply_mask(mask_frame);
Expand Down
20 changes: 12 additions & 8 deletions src/tunnel/transport/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,15 @@ use crate::tunnel::transport::jwt::{tunnel_to_jwt_token, JWT_HEADER_PREFIX};
use crate::tunnel::RemoteAddr;
use anyhow::{anyhow, Context};
use bytes::{Bytes, BytesMut};
use fastwebsockets::{CloseCode, Frame, OpCode, Payload, WebSocketRead, WebSocketWrite};
use fastwebsockets::{CloseCode, Frame, OpCode, Payload, Role, WebSocketRead, WebSocketWrite};
use http_body_util::Empty;
use hyper::header::{AUTHORIZATION, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE};
use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY};
use hyper::http::response::Parts;
use hyper::upgrade::Upgraded;
use hyper::Request;
use hyper_util::rt::TokioExecutor;
use hyper_util::rt::TokioIo;
use log::debug;
use log::{debug, error, info};
use std::io;
use std::io::ErrorKind;
use std::ops::DerefMut;
Expand All @@ -26,9 +25,10 @@ use tokio::sync::mpsc::{Receiver, Sender};
use tokio::sync::Notify;
use tracing::trace;
use uuid::Uuid;
use crate::tunnel::client::l4_transport_stream::TransportStream;

pub struct WebsocketTunnelWrite {
inner: WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>,
inner: WebSocketWrite<WriteHalf<TransportStream>>,
buf: BytesMut,
pending_operations: Receiver<Frame<'static>>,
pending_ops_notify: Arc<Notify>,
Expand All @@ -37,7 +37,7 @@ pub struct WebsocketTunnelWrite {

impl WebsocketTunnelWrite {
pub fn new(
ws: WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>,
ws: WebSocketWrite<WriteHalf<TransportStream>>,
(pending_operations, notify): (Receiver<Frame<'static>>, Arc<Notify>),
) -> Self {
Self {
Expand All @@ -59,6 +59,7 @@ impl TunnelWrite for WebsocketTunnelWrite {
let read_len = self.buf.len();
let buf = &mut self.buf;

info!("write {:?} bytes", String::from_utf8(buf[read_len-10..read_len].to_vec()));
let ret = self
.inner
.write_frame(Frame::binary(Payload::BorrowedMut(&mut buf[..read_len])))
Expand Down Expand Up @@ -146,13 +147,13 @@ impl TunnelWrite for WebsocketTunnelWrite {
}

pub struct WebsocketTunnelRead {
inner: WebSocketRead<ReadHalf<TokioIo<Upgraded>>>,
inner: WebSocketRead<ReadHalf<TransportStream>>,
pending_operations: Sender<Frame<'static>>,
notify_pending_ops: Arc<Notify>,
}

impl WebsocketTunnelRead {
pub fn new(ws: WebSocketRead<ReadHalf<TokioIo<Upgraded>>>) -> (Self, (Receiver<Frame<'static>>, Arc<Notify>)) {
pub fn new(ws: WebSocketRead<ReadHalf<TransportStream>>) -> (Self, (Receiver<Frame<'static>>, Arc<Notify>)) {
let (tx, rx) = tokio::sync::mpsc::channel(10);
let notify = Arc::new(Notify::new());
(
Expand Down Expand Up @@ -278,10 +279,13 @@ pub async fn connect(
})?;
debug!("with HTTP upgrade request {:?}", req);
let transport = pooled_cnx.deref_mut().take().unwrap();
let (mut ws, response) = fastwebsockets::handshake::client(&TokioExecutor::new(), req, transport)
let (ws, response) = fastwebsockets::handshake::client(&TokioExecutor::new(), req, transport)
.await
.with_context(|| format!("failed to do websocket handshake with the server {:?}", client_cfg.remote_addr))?;

let tcp_inner = ws.into_inner().into_inner().downcast::<TokioIo<TransportStream>>().unwrap();
error!("tcp_inner {:?}", tcp_inner.read_buf);
let mut ws = fastwebsockets::WebSocket::after_handshake(tcp_inner.io.into_inner(), Role::Client);
ws.set_auto_apply_mask(client_cfg.websocket_mask_frame);
ws.set_auto_close(false);
ws.set_auto_pong(false);
Expand Down

0 comments on commit 0366502

Please sign in to comment.