From 193c2babc8f80577e8545e3e099f0a5538b6d80f Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Thu, 16 Nov 2023 12:39:10 -0500 Subject: [PATCH] make server auto::Connection use hyper's IO instead of Tokios --- src/common/rewind.rs | 48 ++++++++++++++++++++++++++++++++--------- src/server/conn/auto.rs | 34 ++++++++++++++--------------- 2 files changed, 55 insertions(+), 27 deletions(-) diff --git a/src/common/rewind.rs b/src/common/rewind.rs index 18d8f58..d1602e9 100644 --- a/src/common/rewind.rs +++ b/src/common/rewind.rs @@ -2,7 +2,7 @@ use std::marker::Unpin; use std::{cmp, io}; use bytes::{Buf, Bytes}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use hyper::rt::{Read, ReadBufCursor, Write}; use std::{ pin::Pin, @@ -48,21 +48,21 @@ impl Rewind { // } } -impl AsyncRead for Rewind +impl Read for Rewind where - T: AsyncRead + Unpin, + T: Read + Unpin, { fn poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, - buf: &mut ReadBuf<'_>, + mut buf: ReadBufCursor<'_>, ) -> Poll> { if let Some(mut prefix) = self.pre.take() { // If there are no remaining bytes, let the bytes get dropped. if !prefix.is_empty() { - let copy_len = cmp::min(prefix.len(), buf.remaining()); + let copy_len = cmp::min(prefix.len(), remaining(&mut buf)); // TODO: There should be a way to do following two lines cleaner... - buf.put_slice(&prefix[..copy_len]); + put_slice(&mut buf, &prefix[..copy_len]); prefix.advance(copy_len); // Put back what's left if !prefix.is_empty() { @@ -76,9 +76,37 @@ where } } -impl AsyncWrite for Rewind + +fn remaining(cursor: &mut ReadBufCursor<'_>) -> usize { + // SAFETY: + // We do not uninitialize any set bytes. + unsafe { cursor.as_mut().len() } +} + +// Copied from `ReadBufCursor::put_slice`. +// If that becomes public, we could ditch this. +fn put_slice(cursor: &mut ReadBufCursor<'_>, slice: &[u8]) { + assert!( + remaining(cursor) >= slice.len(), + "buf.len() must fit in remaining()" + ); + + let amt = slice.len(); + + // SAFETY: + // the length is asserted above + unsafe { + cursor.as_mut()[..amt] + .as_mut_ptr() + .cast::() + .copy_from_nonoverlapping(slice.as_ptr(), amt); + cursor.advance(amt); + } +} + +impl Write for Rewind where - T: AsyncWrite + Unpin, + T: Write + Unpin, { fn poll_write( mut self: Pin<&mut Self>, @@ -109,10 +137,9 @@ where } } +/* #[cfg(test)] mod tests { - // FIXME: re-implement tests with `async/await`, this import should - // trigger a warning to remind us use super::Rewind; use bytes::Bytes; use tokio::io::AsyncReadExt; @@ -159,3 +186,4 @@ mod tests { stream.read_exact(&mut buf).await.expect("read1"); } } +*/ diff --git a/src/server/conn/auto.rs b/src/server/conn/auto.rs index e10c201..4761c45 100644 --- a/src/server/conn/auto.rs +++ b/src/server/conn/auto.rs @@ -13,14 +13,13 @@ use http::{Request, Response}; use http_body::Body; use hyper::{ body::Incoming, - rt::{bounds::Http2ServerConnExec, Timer}, + rt::{bounds::Http2ServerConnExec, Read, ReadBuf, Timer, Write}, server::conn::{http1, http2}, service::Service, }; use pin_project_lite::pin_project; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use crate::{common::rewind::Rewind, rt::TokioIo}; +use crate::common::rewind::Rewind; type Result = std::result::Result>; @@ -74,11 +73,10 @@ impl Builder { B: Body + Send + 'static, B::Data: Send, B::Error: Into>, - I: AsyncRead + AsyncWrite + Unpin + 'static, + I: Read + Write + Unpin + 'static, E: Http2ServerConnExec, { let (version, io) = read_version(io).await?; - let io = TokioIo::new(io); match version { Version::H1 => self.http1.serve_connection(io, service).await?, Version::H2 => self.http2.serve_connection(io, service).await?, @@ -98,11 +96,10 @@ impl Builder { B: Body + Send + 'static, B::Data: Send, B::Error: Into>, - I: AsyncRead + AsyncWrite + Unpin + Send + 'static, + I: Read + Write + Unpin + Send + 'static, E: Http2ServerConnExec, { let (version, io) = read_version(io).await?; - let io = TokioIo::new(io); match version { Version::H1 => { self.http1 @@ -123,12 +120,14 @@ enum Version { } async fn read_version<'a, A>(mut reader: A) -> IoResult<(Version, Rewind)> where - A: AsyncRead + Unpin, + A: Read + Unpin, { - let mut buf = [0; 24]; + use std::mem::MaybeUninit; + + let mut buf = [MaybeUninit::uninit(); 24]; let (version, buf) = ReadVersion { reader: &mut reader, - buf: ReadBuf::new(&mut buf), + buf: ReadBuf::uninit(&mut buf), version: Version::H1, _pin: PhantomPinned, } @@ -148,21 +147,21 @@ pin_project! { impl Future for ReadVersion<'_, A> where - A: AsyncRead + Unpin + ?Sized, + A: Read + Unpin + ?Sized, { type Output = IoResult<(Version, Vec)>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll)>> { let this = self.project(); - while this.buf.remaining() != 0 { + while this.buf.filled().len() < H2_PREFACE.len() { if this.buf.filled() != &H2_PREFACE[0..this.buf.filled().len()] { return Poll::Ready(Ok((*this.version, this.buf.filled().to_vec()))); } // if our buffer is empty, then we need to read some data to continue. - let rem = this.buf.remaining(); - ready!(Pin::new(&mut *this.reader).poll_read(cx, this.buf))?; - if this.buf.remaining() == rem { + let len = this.buf.filled().len(); + ready!(Pin::new(&mut *this.reader).poll_read(cx, this.buf.unfilled()))?; + if this.buf.filled().len() == len { return Err(IoError::new(ErrorKind::UnexpectedEof, "early eof")).into(); } } @@ -302,7 +301,7 @@ impl Http1Builder<'_, E> { B: Body + Send + 'static, B::Data: Send, B::Error: Into>, - I: AsyncRead + AsyncWrite + Unpin + 'static, + I: Read + Write + Unpin + 'static, E: Http2ServerConnExec, { self.inner.serve_connection(io, service).await @@ -450,7 +449,7 @@ impl Http2Builder<'_, E> { B: Body + Send + 'static, B::Data: Send, B::Error: Into>, - I: AsyncRead + AsyncWrite + Unpin + 'static, + I: Read + Write + Unpin + 'static, E: Http2ServerConnExec, { self.inner.serve_connection(io, service).await @@ -562,6 +561,7 @@ mod tests { tokio::spawn(async move { loop { let (stream, _) = listener.accept().await.unwrap(); + let stream = TokioIo::new(stream); tokio::task::spawn(async move { let _ = auto::Builder::new(TokioExecutor::new()) .serve_connection(stream, service_fn(hello))