Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Request<T> where T is a type that implements hyper::body::Body trait #1263

Merged
merged 12 commits into from
Sep 26, 2024
2 changes: 2 additions & 0 deletions juniper_hyper/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ All user visible changes to `juniper_hyper` crate will be documented in this fil
- Switched to 0.16 version of [`juniper` crate].
- Switched to 1 version of [`hyper` crate]. ([#1217])
- Changed return type of all functions from `Response<Body>` to `Response<String>`. ([#1101], [#1096])
- Add support for `Request<T>` where `T` is a type that implements `hyper::body::Body` trait. ([#1263])

[#1096]: /../../issues/1096
[#1101]: /../../pull/1101
[#1217]: /../../pull/1217
[#1263]: /../../pull/1263



Expand Down
123 changes: 92 additions & 31 deletions juniper_hyper/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,17 @@ use juniper::{
use serde_json::error::Error as SerdeError;
use url::form_urlencoded;

pub async fn graphql_sync<CtxT, QueryT, MutationT, SubscriptionT, S>(
pub async fn graphql_sync<
CtxT,
QueryT,
MutationT,
SubscriptionT,
S,
T: body::Body<Error = impl Error + 'static>,
>(
root_node: Arc<RootNode<'static, QueryT, MutationT, SubscriptionT, S>>,
context: Arc<CtxT>,
req: Request<body::Incoming>,
req: Request<T>,
) -> Response<String>
where
QueryT: GraphQLType<S, Context = CtxT>,
Expand All @@ -36,10 +43,17 @@ where
}
}

pub async fn graphql<CtxT, QueryT, MutationT, SubscriptionT, S>(
pub async fn graphql<
CtxT,
QueryT,
MutationT,
SubscriptionT,
S,
T: body::Body<Error = impl Error + 'static>,
>(
root_node: Arc<RootNode<'static, QueryT, MutationT, SubscriptionT, S>>,
context: Arc<CtxT>,
req: Request<body::Incoming>,
req: Request<T>,
) -> Response<String>
where
QueryT: GraphQLTypeAsync<S, Context = CtxT>,
Expand All @@ -57,8 +71,8 @@ where
}
}

async fn parse_req<S: ScalarValue>(
req: Request<body::Incoming>,
async fn parse_req<S: ScalarValue, T: body::Body<Error = impl Error + 'static>>(
req: Request<T>,
) -> Result<GraphQLBatchRequest<S>, Response<String>> {
match *req.method() {
Method::GET => parse_get_req(req),
Expand All @@ -78,9 +92,9 @@ async fn parse_req<S: ScalarValue>(
.map_err(render_error)
}

fn parse_get_req<S: ScalarValue>(
req: Request<body::Incoming>,
) -> Result<GraphQLBatchRequest<S>, GraphQLRequestError> {
fn parse_get_req<S: ScalarValue, T: body::Body<Error = impl Error + 'static>>(
req: Request<T>,
) -> Result<GraphQLBatchRequest<S>, GraphQLRequestError<T>> {
req.uri()
.query()
.map(|q| gql_request_from_get(q).map(GraphQLBatchRequest::Single))
Expand All @@ -91,9 +105,9 @@ fn parse_get_req<S: ScalarValue>(
})
}

async fn parse_post_json_req<S: ScalarValue>(
body: body::Incoming,
) -> Result<GraphQLBatchRequest<S>, GraphQLRequestError> {
async fn parse_post_json_req<S: ScalarValue, T: body::Body<Error = impl Error + 'static>>(
body: T,
) -> Result<GraphQLBatchRequest<S>, GraphQLRequestError<T>> {
let chunk = body
.collect()
.await
Expand All @@ -106,9 +120,9 @@ async fn parse_post_json_req<S: ScalarValue>(
.map_err(GraphQLRequestError::BodyJSONError)
}

async fn parse_post_graphql_req<S: ScalarValue>(
body: body::Incoming,
) -> Result<GraphQLBatchRequest<S>, GraphQLRequestError> {
async fn parse_post_graphql_req<S: ScalarValue, T: body::Body>(
body: T,
) -> Result<GraphQLBatchRequest<S>, GraphQLRequestError<T>> {
let chunk = body
.collect()
.await
Expand Down Expand Up @@ -143,7 +157,9 @@ pub async fn playground(
resp
}

fn render_error(err: GraphQLRequestError) -> Response<String> {
fn render_error<T: body::Body<Error = impl Error + 'static>>(
err: GraphQLRequestError<T>,
) -> Response<String> {
let mut resp = new_response(StatusCode::BAD_REQUEST);
*resp.body_mut() = err.to_string();
resp
Expand Down Expand Up @@ -211,7 +227,9 @@ where
resp
}

fn gql_request_from_get<S>(input: &str) -> Result<JuniperGraphQLRequest<S>, GraphQLRequestError>
fn gql_request_from_get<S, T: body::Body>(
input: &str,
) -> Result<JuniperGraphQLRequest<S>, GraphQLRequestError<T>>
where
S: ScalarValue,
{
Expand Down Expand Up @@ -254,7 +272,7 @@ where
}
}

fn invalid_err(parameter_name: &str) -> GraphQLRequestError {
fn invalid_err<T: body::Body>(parameter_name: &str) -> GraphQLRequestError<T> {
GraphQLRequestError::Invalid(format!(
"`{parameter_name}` parameter is specified multiple times",
))
Expand All @@ -276,15 +294,15 @@ fn new_html_response(code: StatusCode) -> Response<String> {
}

#[derive(Debug)]
enum GraphQLRequestError {
BodyHyper(hyper::Error),
enum GraphQLRequestError<T: body::Body> {
BodyHyper(T::Error),
BodyUtf8(FromUtf8Error),
BodyJSONError(SerdeError),
Variables(SerdeError),
Invalid(String),
}

impl fmt::Display for GraphQLRequestError {
impl<T: body::Body<Error = impl Error + 'static>> fmt::Display for GraphQLRequestError<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
GraphQLRequestError::BodyHyper(err) => fmt::Display::fmt(err, f),
Expand All @@ -296,7 +314,10 @@ impl fmt::Display for GraphQLRequestError {
}
}

impl Error for GraphQLRequestError {
impl<T: body::Body<Error = impl Error + 'static> + std::fmt::Debug> Error for GraphQLRequestError<T>
where
<T as body::Body>::Error: std::fmt::Debug,
{
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
GraphQLRequestError::BodyHyper(err) => Some(err),
Expand All @@ -314,7 +335,11 @@ mod tests {
convert::Infallible, error::Error, net::SocketAddr, panic, sync::Arc, time::Duration,
};

use hyper::{server::conn::http1, service::service_fn, Method, Response, StatusCode};
use http_body_util::BodyExt;
use hyper::{
body::Incoming, server::conn::http1, service::service_fn, Method, Request, Response,
StatusCode,
};
use hyper_util::rt::TokioIo;
use juniper::{
http::tests as http_tests,
Expand Down Expand Up @@ -376,8 +401,12 @@ mod tests {
}
}

async fn run_hyper_integration(is_sync: bool) {
let port = if is_sync { 3002 } else { 3001 };
static mut PORT: u16 = 3001;
async fn run_hyper_integration(is_sync: bool, is_custom_type: bool) {
let port = unsafe {
tyranron marked this conversation as resolved.
Show resolved Hide resolved
PORT = PORT.wrapping_add(1);
PORT
} + if is_sync { 1000 } else { 0 };
let addr = SocketAddr::from(([127, 0, 0, 1], port));

let db = Arc::new(Database::new());
Expand Down Expand Up @@ -405,7 +434,7 @@ mod tests {
if let Err(e) = http1::Builder::new()
.serve_connection(
io,
service_fn(move |req| {
service_fn(move |req: Request<Incoming>| {
let root_node = root_node.clone();
let db = db.clone();
let matches = {
Expand All @@ -419,10 +448,30 @@ mod tests {
};
async move {
Ok::<_, Infallible>(if matches {
if is_sync {
super::graphql_sync(root_node, db, req).await
if is_custom_type {
let (parts, mut body) = req.into_parts();
let body = {
let mut buf = String::new();
if let Some(Ok(frame)) = body.frame().await {
if let Ok(bytes) = frame.into_data() {
buf = String::from_utf8_lossy(&bytes)
tyranron marked this conversation as resolved.
Show resolved Hide resolved
.to_string();
}
}
buf
};
let req = Request::from_parts(parts, body);
if is_sync {
super::graphql_sync(root_node, db, req).await
} else {
super::graphql(root_node, db, req).await
}
} else {
super::graphql(root_node, db, req).await
if is_sync {
super::graphql_sync(root_node, db, req).await
} else {
super::graphql(root_node, db, req).await
}
}
} else {
let mut resp = Response::new(String::new());
Expand Down Expand Up @@ -460,11 +509,23 @@ mod tests {

#[tokio::test]
async fn test_hyper_integration() {
run_hyper_integration(false).await
run_hyper_integration(false, false).await
}

#[tokio::test]
async fn test_sync_hyper_integration() {
run_hyper_integration(true).await
run_hyper_integration(true, false).await
}

#[tokio::test]
/// run test for a custom request type - `Request<Vec<u8>>`
async fn test_custom_hyper_integration() {
run_hyper_integration(false, false).await
}

#[tokio::test]
/// run test for a custom request type - `Request<Vec<u8>>` in sync mode
async fn test_custom_sync_hyper_integration() {
run_hyper_integration(true, true).await
}
}