From 8854e660e9ab07404e5bb8e30b92311d3848de05 Mon Sep 17 00:00:00 2001 From: Rakshith Ravi Date: Sun, 1 Oct 2023 14:16:22 +0530 Subject: [PATCH] Add HTTP/1 and HTTP/2 to `axum::serve` (#2241) --- axum/Cargo.toml | 5 +- axum/src/extract/ws.rs | 3 +- axum/src/hyper1_tokio_io.rs | 161 ------------------------------------ axum/src/lib.rs | 2 - axum/src/serve.rs | 15 ++-- 5 files changed, 13 insertions(+), 173 deletions(-) delete mode 100644 axum/src/hyper1_tokio_io.rs diff --git a/axum/Cargo.toml b/axum/Cargo.toml index 469a1c5ebd..e0062a099f 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -22,7 +22,7 @@ matched-path = [] multipart = ["dep:multer"] original-uri = [] query = ["dep:serde_urlencoded"] -tokio = ["dep:tokio", "hyper/server", "hyper/tcp", "hyper/runtime", "tower/make"] +tokio = ["dep:tokio", "dep:hyper-util", "hyper/server", "hyper/tcp", "hyper/runtime", "tower/make"] tower-log = ["tower/log"] tracing = ["dep:tracing", "axum-core/tracing"] ws = ["tokio", "dep:tokio-tungstenite", "dep:sha1", "dep:base64"] @@ -51,12 +51,13 @@ tower-layer = "0.3.2" tower-service = "0.3" # wont need this when axum uses http-body 1.0 -hyper1 = { package = "hyper", version = "=1.0.0-rc.4", features = ["server", "http1"] } +hyper1 = { package = "hyper", version = "=1.0.0-rc.4", features = ["server", "http1", "http2"] } tower-hyper-http-body-compat = { version = "0.2", features = ["server", "http1"] } # optional dependencies axum-macros = { path = "../axum-macros", version = "0.3.7", optional = true } base64 = { version = "0.21.0", optional = true } +hyper-util = { git = "https://github.com/hyperium/hyper-util", rev = "d97181a", features = ["auto"], optional = true } multer = { version = "2.0.0", optional = true } serde_json = { version = "1.0", features = ["raw_value"], optional = true } serde_path_to_error = { version = "0.1.8", optional = true } diff --git a/axum/src/extract/ws.rs b/axum/src/extract/ws.rs index 966776aca9..a5f20a6abb 100644 --- a/axum/src/extract/ws.rs +++ b/axum/src/extract/ws.rs @@ -92,7 +92,7 @@ use self::rejection::*; use super::FromRequestParts; -use crate::{body::Bytes, hyper1_tokio_io::TokioIo, response::Response, Error}; +use crate::{body::Bytes, response::Response, Error}; use async_trait::async_trait; use axum_core::body::Body; use futures_util::{ @@ -104,6 +104,7 @@ use http::{ request::Parts, Method, StatusCode, }; +use hyper_util::rt::TokioIo; use sha1::{Digest, Sha1}; use std::{ borrow::Cow, diff --git a/axum/src/hyper1_tokio_io.rs b/axum/src/hyper1_tokio_io.rs deleted file mode 100644 index 474df41217..0000000000 --- a/axum/src/hyper1_tokio_io.rs +++ /dev/null @@ -1,161 +0,0 @@ -// Copied from https://github.com/hyperium/hyper-util/blob/master/src/rt/tokio_io.rs - -#![allow(unsafe_code)] - -//! Tokio IO integration for hyper -use hyper1 as hyper; -use std::{ - pin::Pin, - task::{Context, Poll}, -}; - -use pin_project_lite::pin_project; - -pin_project! { - /// A wrapping implementing hyper IO traits for a type that - /// implements Tokio's IO traits. - #[derive(Debug)] - pub(crate) struct TokioIo { - #[pin] - inner: T, - } -} - -impl TokioIo { - /// Wrap a type implementing Tokio's IO traits. - pub(crate) fn new(inner: T) -> Self { - Self { inner } - } - - /// Borrow the inner type. - pub(crate) fn inner(&self) -> &T { - &self.inner - } -} - -impl hyper::rt::Read for TokioIo -where - T: tokio::io::AsyncRead, -{ - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - mut buf: hyper::rt::ReadBufCursor<'_>, - ) -> Poll> { - let n = unsafe { - let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); - match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) { - Poll::Ready(Ok(())) => tbuf.filled().len(), - other => return other, - } - }; - - unsafe { - buf.advance(n); - } - Poll::Ready(Ok(())) - } -} - -impl hyper::rt::Write for TokioIo -where - T: tokio::io::AsyncWrite, -{ - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) - } - - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) - } - - fn is_write_vectored(&self) -> bool { - tokio::io::AsyncWrite::is_write_vectored(&self.inner) - } - - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[std::io::IoSlice<'_>], - ) -> Poll> { - tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) - } -} - -impl tokio::io::AsyncRead for TokioIo -where - T: hyper::rt::Read, -{ - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - tbuf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { - let filled = tbuf.filled().len(); - let sub_filled = unsafe { - let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut()); - - match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) { - Poll::Ready(Ok(())) => buf.filled().len(), - other => return other, - } - }; - - let n_filled = filled + sub_filled; - // At least sub_filled bytes had to have been initialized. - let n_init = sub_filled; - unsafe { - tbuf.assume_init(n_init); - tbuf.set_filled(n_filled); - } - - Poll::Ready(Ok(())) - } -} - -impl tokio::io::AsyncWrite for TokioIo -where - T: hyper::rt::Write, -{ - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - hyper::rt::Write::poll_write(self.project().inner, cx, buf) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - hyper::rt::Write::poll_flush(self.project().inner, cx) - } - - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - hyper::rt::Write::poll_shutdown(self.project().inner, cx) - } - - fn is_write_vectored(&self) -> bool { - hyper::rt::Write::is_write_vectored(&self.inner) - } - - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[std::io::IoSlice<'_>], - ) -> Poll> { - hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs) - } -} diff --git a/axum/src/lib.rs b/axum/src/lib.rs index 3a64dd73d1..71eb74fba2 100644 --- a/axum/src/lib.rs +++ b/axum/src/lib.rs @@ -432,8 +432,6 @@ mod boxed; mod extension; #[cfg(feature = "form")] mod form; -#[cfg(feature = "tokio")] -mod hyper1_tokio_io; #[cfg(feature = "json")] mod json; mod service_ext; diff --git a/axum/src/serve.rs b/axum/src/serve.rs index 7f825ef6db..f2b49e2e8c 100644 --- a/axum/src/serve.rs +++ b/axum/src/serve.rs @@ -2,10 +2,12 @@ use std::{convert::Infallible, io, net::SocketAddr}; -use crate::hyper1_tokio_io::TokioIo; use axum_core::{body::Body, extract::Request, response::Response}; use futures_util::{future::poll_fn, FutureExt}; -use hyper1::server::conn::http1; +use hyper_util::{ + rt::{TokioExecutor, TokioIo}, + server::conn::auto::Builder, +}; use tokio::net::{TcpListener, TcpStream}; use tower_hyper_http_body_compat::{HttpBody04ToHttpBody1, HttpBody1ToHttpBody04}; use tower_service::Service; @@ -15,7 +17,7 @@ use tower_service::Service; /// This method of running a service is intentionally simple and doesn't support any configuration. /// Use hyper or hyper-util if you need configuration. /// -/// It only supports HTTP/1. +/// It supports both HTTP/1 as well as HTTP/2. /// /// # Examples /// @@ -138,10 +140,9 @@ where }); tokio::task::spawn(async move { - match http1::Builder::new() - .serve_connection(tcp_stream, service) - // for websockets - .with_upgrades() + match Builder::new(TokioExecutor::new()) + // upgrades needed for websockets + .serve_connection_with_upgrades(tcp_stream.into_inner(), service) .await { Ok(()) => {}