Skip to content

Commit

Permalink
Use auto_pong and auto_close instead of creating our own frames
Browse files Browse the repository at this point in the history
  • Loading branch information
r-vdp committed Jul 14, 2024
1 parent bbbf649 commit c93cb6b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 43 deletions.
6 changes: 1 addition & 5 deletions src/tunnel/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -546,11 +546,7 @@ async fn ws_server_upgrade(
tokio::spawn(
async move {
let (ws_rx, mut ws_tx) = match fut.await {
Ok(mut ws) => {
ws.set_auto_pong(false);
ws.set_auto_close(false);
ws.split(tokio::io::split)
}
Ok(ws) => ws.split(tokio::io::split),
Err(err) => {
error!("Error during http upgrade request: {:?}", err);
return;
Expand Down
60 changes: 22 additions & 38 deletions src/tunnel/transport/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,9 @@ use tracing::{debug, trace};
use uuid::Uuid;

// Messages that can be passed from the reader half to the writer half
#[derive(Debug)]
pub enum WebSocketTunnelMessage {
Ping(u8),
Pong(u8),
Close,
SendFrame(Frame<'static>),
}

#[derive(Debug)]
Expand Down Expand Up @@ -166,20 +164,11 @@ impl TunnelWrite for WebsocketTunnelWrite {
self.ping_state.set_pong_seq(seq);
Ok(())
}
Ok(WebSocketTunnelMessage::Ping(seq)) => {
debug!("Sending pong({})", seq);
self.inner
.write_frame(Frame::pong(Payload::BorrowedMut(&mut [seq])))
.await
.map_err(|err| io::Error::new(ErrorKind::BrokenPipe, err))
}
Ok(WebSocketTunnelMessage::Close) => {
debug!("Sending close confirmation");
self.inner
.write_frame(Frame::close(1000, &[]))
.await
.map_err(|err| io::Error::new(ErrorKind::BrokenPipe, err))
}
Ok(WebSocketTunnelMessage::SendFrame(frame)) => self
.inner
.write_frame(frame)
.await
.map_err(|err| io::Error::new(ErrorKind::BrokenPipe, err)),
Err(TryRecvError::Empty) => Ok(()),
Err(TryRecvError::Disconnected) => Err(io::Error::new(ErrorKind::BrokenPipe, "channel closed")),
}
Expand Down Expand Up @@ -212,15 +201,20 @@ impl WebsocketTunnelRead {
}
}

// Since we disable auto_pong and auto_close, we should never end up here.
// So let's panic so that we don't accidentally end up calling this.
fn frame_reader(_: Frame<'_>) -> futures_util::future::Ready<anyhow::Result<()>> {
unimplemented!()
async fn send_frame(sender: &Sender<WebSocketTunnelMessage>, frame: Frame<'static>) -> Result<(), io::Error> {
sender
.send(WebSocketTunnelMessage::SendFrame(frame))
.await
.map_err(|err| io::Error::new(ErrorKind::ConnectionAborted, err))
}

impl TunnelRead for WebsocketTunnelRead {
async fn copy(&mut self, mut writer: impl AsyncWrite + Unpin + Send) -> Result<(), io::Error> {
let msg = match self.inner.read_frame(&mut frame_reader).await {
let msg = match self
.inner
.read_frame(&mut |f| send_frame(&self.send_to_writer, f))
.await
{
Ok(msg) => msg,
Err(err) => return Err(io::Error::new(ErrorKind::ConnectionAborted, err)),
};
Expand All @@ -233,20 +227,6 @@ impl TunnelRead for WebsocketTunnelRead {
Err(err) => Err(io::Error::new(ErrorKind::ConnectionAborted, err)),
}
}
OpCode::Close => {
// Sending back the close confirmation is best effort, if we fail, we just close
// the connection anyway
_ = self.send_to_writer.send(WebSocketTunnelMessage::Close).await;
Err(io::Error::new(ErrorKind::NotConnected, "websocket close"))
}
OpCode::Ping => {
let seq = msg.payload[0];
debug!("Received ping({})", seq);
self.send_to_writer
.send(WebSocketTunnelMessage::Ping(seq))
.await
.map_err(|err| io::Error::new(ErrorKind::ConnectionAborted, err))
}
OpCode::Pong => {
let seq = msg.payload[0];
debug!("Received pong({})", seq);
Expand All @@ -255,6 +235,12 @@ impl TunnelRead for WebsocketTunnelRead {
.await
.map_err(|err| io::Error::new(ErrorKind::ConnectionAborted, err))
}
// We use auto_close so the write half will automatically get closed
// and we automatically send a close frame back.
// We can just break out of the read loop here.
OpCode::Close => Err(io::Error::new(ErrorKind::NotConnected, "Connection closed")),
// We use auto_pong , so we'll never see this variant
OpCode::Ping => unimplemented!(),
}
}
}
Expand Down Expand Up @@ -319,8 +305,6 @@ pub async fn connect(
.with_context(|| format!("failed to do websocket handshake with the server {:?}", client_cfg.remote_addr))?;

ws.set_auto_apply_mask(client_cfg.websocket_mask_frame);
ws.set_auto_pong(false);
ws.set_auto_close(false);

let (ws_rx, ws_tx) = ws.split(tokio::io::split);
let (ch_tx, ch_rx) = mpsc::channel::<WebSocketTunnelMessage>(32);
Expand Down

0 comments on commit c93cb6b

Please sign in to comment.