Skip to content

Commit

Permalink
Make the layer axum compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkusTieger committed Jan 22, 2025
1 parent 88aed0e commit 058fb6e
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 15 deletions.
4 changes: 4 additions & 0 deletions tonic-web/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ tower-service = "0.3"
tower-layer = "0.3"
tracing = "0.1"

[dependencies.axum]
version = "0.8.1"
optional = true

[dev-dependencies]
tokio = { version = "1", features = ["macros", "rt"] }
tower-http = { version = "0.6", features = ["cors"] }
Expand Down
77 changes: 62 additions & 15 deletions tonic-web/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ use std::task::{ready, Context, Poll};
use http::{header, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Version};
use pin_project::pin_project;
use tonic::metadata::GRPC_CONTENT_TYPE;
use tonic::{body::Body, server::NamedService};
use tonic::{server::NamedService};
use tower_service::Service;
use tracing::{debug, trace};

use crate::call::content_types::is_grpc_web;
use crate::call::{Encoding, GrpcWebCall};

use bytes::Bytes;

/// Service implementing the grpc-web protocol.
#[derive(Debug, Clone)]
pub struct GrpcWebService<S> {
Expand Down Expand Up @@ -45,9 +47,9 @@ impl<S> GrpcWebService<S> {

impl<S, B> Service<Request<B>> for GrpcWebService<S>
where
S: Service<Request<Body>, Response = Response<Body>>,
B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
B::Error: Into<crate::BoxError> + fmt::Display,
S: Service<Request<B>, Response = Response<B>>,
B: http_body::Body<Data = bytes::Bytes> + BoxedBody + Send + 'static,
B::Error: Into<crate::BoxError> + std::error::Error + fmt::Display + Send + Sync,
{
type Response = S::Response;
type Error = S::Error;
Expand Down Expand Up @@ -100,7 +102,7 @@ where
debug!(kind = "other h2", content_type = ?req.headers().get(header::CONTENT_TYPE));
ResponseFuture {
case: Case::Other {
future: self.inner.call(req.map(Body::new)),
future: self.inner.call(req.map(B::new)),
},
}
}
Expand Down Expand Up @@ -152,11 +154,13 @@ impl<F> Case<F> {
}
}

impl<F, E> Future for ResponseFuture<F>
impl<F, E, A> Future for ResponseFuture<F>
where
F: Future<Output = Result<Response<Body>, E>>,
F: Future<Output = Result<Response<A>, E>>,
A: BoxedBody + 'static,
<A as http_body::Body>::Error: std::error::Error + Send + Sync + 'static
{
type Output = Result<Response<Body>, E>;
type Output = Result<Response<A>, E>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
Expand All @@ -169,7 +173,7 @@ where
}
CaseProj::Other { future } => future.poll(cx),
CaseProj::ImmediateResponse { res } => {
let res = Response::from_parts(res.take().unwrap(), Body::empty());
let res = Response::from_parts(res.take().unwrap(), A::empty());
Poll::Ready(Ok(res))
}
}
Expand Down Expand Up @@ -203,9 +207,9 @@ impl<'a> RequestKind<'a> {
// Mutating request headers to conform to a gRPC request is not really
// necessary for us at this point. We could remove most of these except
// maybe for inserting `header::TE`, which tonic should check?
fn coerce_request<B>(mut req: Request<B>, encoding: Encoding) -> Request<Body>
fn coerce_request<B>(mut req: Request<B>, encoding: Encoding) -> Request<B>
where
B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
B: http_body::Body<Data = bytes::Bytes> + BoxedBody + Send + 'static,
B::Error: Into<crate::BoxError> + fmt::Display,
{
req.headers_mut().remove(header::CONTENT_LENGTH);
Expand All @@ -221,17 +225,17 @@ where
HeaderValue::from_static("identity,deflate,gzip"),
);

req.map(|b| Body::new(GrpcWebCall::request(b, encoding)))
req.map(|b| B::new(GrpcWebCall::request(b, encoding)))
}

fn coerce_response<B>(res: Response<B>, encoding: Encoding) -> Response<Body>
fn coerce_response<B>(res: Response<B>, encoding: Encoding) -> Response<B>
where
B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
B: http_body::Body<Data = bytes::Bytes> + BoxedBody + Send + 'static,
B::Error: Into<crate::BoxError> + fmt::Display,
{
let mut res = res
.map(|b| GrpcWebCall::response(b, encoding))
.map(Body::new);
.map(B::new);

res.headers_mut().insert(
header::CONTENT_TYPE,
Expand All @@ -241,6 +245,49 @@ where
res
}

/// Alias for a type-erased error type.
type BoxError = Box<dyn std::error::Error + Send + Sync>;

trait BoxedBody: http_body::Body<Data = bytes::Bytes> + Send {

fn new<B>(body: B) -> Self
where
B: http_body::Body<Data = Bytes> + Send + 'static,
B::Error: Into<BoxError>;

fn empty() -> Self;

}

impl BoxedBody for tonic::body::Body {

fn new<B>(body: B) -> Self
where
B: http_body::Body<Data = Bytes> + Send + 'static,
B::Error: Into<BoxError> {
Self::new(body)
}

fn empty() -> Self {
Self::empty()
}

}
#[cfg(feature = "axum")]
impl BoxedBody for axum::body::Body {

fn new<B>(body: B) -> Self
where
B: http_body::Body<Data = Bytes> + Send + 'static,
B::Error: Into<BoxError> {
Self::new(body)
}

fn empty() -> Self {
Self::empty()
}

}
#[cfg(test)]
mod tests {
use super::*;
Expand Down

0 comments on commit 058fb6e

Please sign in to comment.