Skip to content

Commit

Permalink
make server auto::Connection use hyper's IO instead of Tokios (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmonstar authored Nov 16, 2023
1 parent a89ea05 commit 02dc44f
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 27 deletions.
47 changes: 37 additions & 10 deletions src/common/rewind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::marker::Unpin;
use std::{cmp, io};

use bytes::{Buf, Bytes};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use hyper::rt::{Read, ReadBufCursor, Write};

use std::{
pin::Pin,
Expand Down Expand Up @@ -48,21 +48,21 @@ impl<T> Rewind<T> {
// }
}

impl<T> AsyncRead for Rewind<T>
impl<T> Read for Rewind<T>
where
T: AsyncRead + Unpin,
T: Read + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
mut buf: ReadBufCursor<'_>,
) -> Poll<io::Result<()>> {
if let Some(mut prefix) = self.pre.take() {
// If there are no remaining bytes, let the bytes get dropped.
if !prefix.is_empty() {
let copy_len = cmp::min(prefix.len(), buf.remaining());
let copy_len = cmp::min(prefix.len(), remaining(&mut buf));
// TODO: There should be a way to do following two lines cleaner...
buf.put_slice(&prefix[..copy_len]);
put_slice(&mut buf, &prefix[..copy_len]);
prefix.advance(copy_len);
// Put back what's left
if !prefix.is_empty() {
Expand All @@ -76,9 +76,36 @@ where
}
}

impl<T> AsyncWrite for Rewind<T>
fn remaining(cursor: &mut ReadBufCursor<'_>) -> usize {
// SAFETY:
// We do not uninitialize any set bytes.
unsafe { cursor.as_mut().len() }
}

// Copied from `ReadBufCursor::put_slice`.
// If that becomes public, we could ditch this.
fn put_slice(cursor: &mut ReadBufCursor<'_>, slice: &[u8]) {
assert!(
remaining(cursor) >= slice.len(),
"buf.len() must fit in remaining()"
);

let amt = slice.len();

// SAFETY:
// the length is asserted above
unsafe {
cursor.as_mut()[..amt]
.as_mut_ptr()
.cast::<u8>()
.copy_from_nonoverlapping(slice.as_ptr(), amt);
cursor.advance(amt);
}
}

impl<T> Write for Rewind<T>
where
T: AsyncWrite + Unpin,
T: Write + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
Expand Down Expand Up @@ -109,10 +136,9 @@ where
}
}

/*
#[cfg(test)]
mod tests {
// FIXME: re-implement tests with `async/await`, this import should
// trigger a warning to remind us
use super::Rewind;
use bytes::Bytes;
use tokio::io::AsyncReadExt;
Expand Down Expand Up @@ -159,3 +185,4 @@ mod tests {
stream.read_exact(&mut buf).await.expect("read1");
}
}
*/
34 changes: 17 additions & 17 deletions src/server/conn/auto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@ use http::{Request, Response};
use http_body::Body;
use hyper::{
body::Incoming,
rt::{bounds::Http2ServerConnExec, Timer},
rt::{bounds::Http2ServerConnExec, Read, ReadBuf, Timer, Write},
server::conn::{http1, http2},
service::Service,
};
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

use crate::{common::rewind::Rewind, rt::TokioIo};
use crate::common::rewind::Rewind;

type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

Expand Down Expand Up @@ -74,11 +73,10 @@ impl<E> Builder<E> {
B: Body + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
I: AsyncRead + AsyncWrite + Unpin + 'static,
I: Read + Write + Unpin + 'static,
E: Http2ServerConnExec<S::Future, B>,
{
let (version, io) = read_version(io).await?;
let io = TokioIo::new(io);
match version {
Version::H1 => self.http1.serve_connection(io, service).await?,
Version::H2 => self.http2.serve_connection(io, service).await?,
Expand All @@ -98,11 +96,10 @@ impl<E> Builder<E> {
B: Body + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
I: Read + Write + Unpin + Send + 'static,
E: Http2ServerConnExec<S::Future, B>,
{
let (version, io) = read_version(io).await?;
let io = TokioIo::new(io);
match version {
Version::H1 => {
self.http1
Expand All @@ -123,12 +120,14 @@ enum Version {
}
async fn read_version<'a, A>(mut reader: A) -> IoResult<(Version, Rewind<A>)>
where
A: AsyncRead + Unpin,
A: Read + Unpin,
{
let mut buf = [0; 24];
use std::mem::MaybeUninit;

let mut buf = [MaybeUninit::uninit(); 24];
let (version, buf) = ReadVersion {
reader: &mut reader,
buf: ReadBuf::new(&mut buf),
buf: ReadBuf::uninit(&mut buf),
version: Version::H1,
_pin: PhantomPinned,
}
Expand All @@ -148,21 +147,21 @@ pin_project! {

impl<A> Future for ReadVersion<'_, A>
where
A: AsyncRead + Unpin + ?Sized,
A: Read + Unpin + ?Sized,
{
type Output = IoResult<(Version, Vec<u8>)>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<(Version, Vec<u8>)>> {
let this = self.project();

while this.buf.remaining() != 0 {
while this.buf.filled().len() < H2_PREFACE.len() {
if this.buf.filled() != &H2_PREFACE[0..this.buf.filled().len()] {
return Poll::Ready(Ok((*this.version, this.buf.filled().to_vec())));
}
// if our buffer is empty, then we need to read some data to continue.
let rem = this.buf.remaining();
ready!(Pin::new(&mut *this.reader).poll_read(cx, this.buf))?;
if this.buf.remaining() == rem {
let len = this.buf.filled().len();
ready!(Pin::new(&mut *this.reader).poll_read(cx, this.buf.unfilled()))?;
if this.buf.filled().len() == len {
return Err(IoError::new(ErrorKind::UnexpectedEof, "early eof")).into();
}
}
Expand Down Expand Up @@ -302,7 +301,7 @@ impl<E> Http1Builder<'_, E> {
B: Body + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
I: AsyncRead + AsyncWrite + Unpin + 'static,
I: Read + Write + Unpin + 'static,
E: Http2ServerConnExec<S::Future, B>,
{
self.inner.serve_connection(io, service).await
Expand Down Expand Up @@ -450,7 +449,7 @@ impl<E> Http2Builder<'_, E> {
B: Body + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
I: AsyncRead + AsyncWrite + Unpin + 'static,
I: Read + Write + Unpin + 'static,
E: Http2ServerConnExec<S::Future, B>,
{
self.inner.serve_connection(io, service).await
Expand Down Expand Up @@ -562,6 +561,7 @@ mod tests {
tokio::spawn(async move {
loop {
let (stream, _) = listener.accept().await.unwrap();
let stream = TokioIo::new(stream);
tokio::task::spawn(async move {
let _ = auto::Builder::new(TokioExecutor::new())
.serve_connection(stream, service_fn(hello))
Expand Down

0 comments on commit 02dc44f

Please sign in to comment.