Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(server): add AutoConnection #11

Merged
merged 14 commits into from
Sep 16, 2023
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ hyper = "=1.0.0-rc.4"
futures-channel = "0.3"
futures-util = { version = "0.3", default-features = false }
http = "0.2"
http-body = "1.0.0-rc.2"
bytes = "1"

once_cell = "1.14"

Expand All @@ -30,9 +32,11 @@ tower-service = "0.3"
tower = { version = "0.4", features = ["make", "util"] }

[dev-dependencies]
hyper = { version = "1.0.0-rc.3", features = ["full"] }
bytes = "1"
http-body-util = "0.1.0-rc.3"
tokio = { version = "1", features = ["macros", "test-util"] }
tokio-test = "0.4"

[target.'cfg(any(target_os = "linux", target_os = "macos"))'.dev-dependencies]
pnet_datalink = "0.27.2"
Expand All @@ -50,6 +54,7 @@ http1 = ["hyper/http1"]
http2 = ["hyper/http2"]

tcp = []
auto = ["hyper/server", "http1", "http2"]
runtime = []

# internal features used in CI
Expand Down
1 change: 1 addition & 0 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub mod exec;
#[cfg(feature = "client")]
mod lazy;
pub(crate) mod never;
pub(crate) mod rewind;
#[cfg(feature = "client")]
mod sync;

Expand Down
161 changes: 161 additions & 0 deletions src/common/rewind.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
use std::marker::Unpin;
use std::{cmp, io};

use bytes::{Buf, Bytes};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

use std::{
pin::Pin,
task::{self, Poll},
};

/// Combine a buffer with an IO, rewinding reads to use the buffer.
#[derive(Debug)]
pub(crate) struct Rewind<T> {
pre: Option<Bytes>,
inner: T,
}

impl<T> Rewind<T> {
#[cfg(test)]
pub(crate) fn new(io: T) -> Self {
Rewind {
pre: None,
inner: io,
}
}

#[allow(dead_code)]
pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self {
Rewind {
pre: Some(buf),
inner: io,
}
}

#[cfg(test)]
pub(crate) fn rewind(&mut self, bs: Bytes) {
debug_assert!(self.pre.is_none());
self.pre = Some(bs);
}

// pub(crate) fn into_inner(self) -> (T, Bytes) {
// (self.inner, self.pre.unwrap_or_else(Bytes::new))
// }

// pub(crate) fn get_mut(&mut self) -> &mut T {
// &mut self.inner
// }
}

impl<T> AsyncRead for Rewind<T>
where
T: AsyncRead + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if let Some(mut prefix) = self.pre.take() {
// If there are no remaining bytes, let the bytes get dropped.
if !prefix.is_empty() {
let copy_len = cmp::min(prefix.len(), buf.remaining());
// TODO: There should be a way to do following two lines cleaner...
buf.put_slice(&prefix[..copy_len]);
prefix.advance(copy_len);
programatik29 marked this conversation as resolved.
Show resolved Hide resolved
// Put back what's left
if !prefix.is_empty() {
self.pre = Some(prefix);
}

return Poll::Ready(Ok(()));
}
}
Pin::new(&mut self.inner).poll_read(cx, buf)
}
}

impl<T> AsyncWrite for Rewind<T>
where
T: AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write(cx, buf)
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
}

fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
}

#[cfg(test)]
mod tests {
// FIXME: re-implement tests with `async/await`, this import should
// trigger a warning to remind us
use super::Rewind;
use bytes::Bytes;
use tokio::io::AsyncReadExt;

#[cfg(not(miri))]
#[tokio::test]
async fn partial_rewind() {
let underlying = [104, 101, 108, 108, 111];

let mock = tokio_test::io::Builder::new().read(&underlying).build();

let mut stream = Rewind::new(mock);

// Read off some bytes, ensure we filled o1
let mut buf = [0; 2];
stream.read_exact(&mut buf).await.expect("read1");

// Rewind the stream so that it is as if we never read in the first place.
stream.rewind(Bytes::copy_from_slice(&buf[..]));

let mut buf = [0; 5];
stream.read_exact(&mut buf).await.expect("read1");

// At this point we should have read everything that was in the MockStream
assert_eq!(&buf, &underlying);
}

#[cfg(not(miri))]
#[tokio::test]
async fn full_rewind() {
let underlying = [104, 101, 108, 108, 111];

let mock = tokio_test::io::Builder::new().read(&underlying).build();

let mut stream = Rewind::new(mock);

let mut buf = [0; 5];
stream.read_exact(&mut buf).await.expect("read1");

// Rewind the stream so that it is as if we never read in the first place.
stream.rewind(Bytes::copy_from_slice(&buf[..]));

let mut buf = [0; 5];
stream.read_exact(&mut buf).await.expect("read1");
}
}
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
#![deny(missing_docs)]
#![cfg_attr(docsrs, feature(doc_auto_cfg, doc_cfg))]

//! hyper-util

#[cfg(feature = "client")]
pub mod client;
mod common;
pub mod rt;
pub mod server;

mod error;
Loading