Skip to content

Commit

Permalink
edge-http: make fields in {Req,Resp}Headers non-optional (#37)
Browse files Browse the repository at this point in the history
* edge-http: make fields in {Req,Resp}Headers non-optional

Ensures that the {Req,Resp}Headers is always complete and
adheres to the HTTP spec. Saves a lot code from dealing with
`Option`s.

Also, `send_status_line` is removed, and inlined into
`send_request` and `send_status`.

* Fix is_upgrade_request

* Let receive be a method, rather than a constructor function

* Add inline attribute to default

* Restore the const constructor `new()` and call it in the library
  • Loading branch information
showier-drastic authored Oct 11, 2024
1 parent 71810bd commit 997b76f
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 170 deletions.
166 changes: 55 additions & 111 deletions edge-http/src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,18 +128,15 @@ impl<'b, const N: usize> RequestHeaders<'b, N> {
unreachable!("Should not happen. HTTP header parsing is indeterminate.")
}

self.http11 = if let Some(version) = parser.version {
if version > 1 {
Err(Error::InvalidHeaders)?;
}

Some(version == 1)
} else {
None
self.http11 = match parser.version {
Some(0) => false,
Some(1) => true,
_ => Err(Error::InvalidHeaders)?,
};

self.method = parser.method.and_then(Method::new);
self.path = parser.path;
let method_str = parser.method.ok_or(Error::InvalidHeaders)?;
self.method = Method::new(method_str).ok_or(Error::InvalidHeaders)?;
self.path = parser.path.ok_or(Error::InvalidHeaders)?;

trace!("Received:\n{}", self);

Expand All @@ -151,8 +148,7 @@ impl<'b, const N: usize> RequestHeaders<'b, N> {

/// Resolve the connection type and body type from the headers
pub fn resolve<E>(&self) -> Result<(ConnectionType, BodyType), Error<E>> {
self.headers
.resolve::<E>(None, true, self.http11.unwrap_or(false))
self.headers.resolve::<E>(None, true, self.http11)
}

/// Send the headers to the output stream, returning the connection type and body type
Expand All @@ -164,12 +160,10 @@ impl<'b, const N: usize> RequestHeaders<'b, N> {
where
W: Write,
{
let http11 = self.http11.unwrap_or(false);

send_request(http11, self.method, self.path, &mut output).await?;
send_request(self.http11, self.method, self.path, &mut output).await?;

self.headers
.send(None, true, http11, chunked_if_unspecified, output)
.send(None, true, self.http11, chunked_if_unspecified, output)
.await
}
}
Expand Down Expand Up @@ -199,17 +193,13 @@ impl<'b, const N: usize> ResponseHeaders<'b, N> {
unreachable!("Should not happen. HTTP header parsing is indeterminate.")
}

self.http11 = if let Some(version) = parser.version {
if version > 1 {
Err(Error::InvalidHeaders)?;
}

Some(version == 1)
} else {
None
self.http11 = match parser.version {
Some(0) => false,
Some(1) => true,
_ => Err(Error::InvalidHeaders)?,
};

self.code = parser.code;
self.code = parser.code.ok_or(Error::InvalidHeaders)?;
self.reason = parser.reason;

trace!("Received:\n{}", self);
Expand All @@ -225,11 +215,8 @@ impl<'b, const N: usize> ResponseHeaders<'b, N> {
&self,
request_connection_type: ConnectionType,
) -> Result<(ConnectionType, BodyType), Error<E>> {
self.headers.resolve::<E>(
Some(request_connection_type),
false,
self.http11.unwrap_or(false),
)
self.headers
.resolve::<E>(Some(request_connection_type), false, self.http11)
}

/// Send the headers to the output stream, returning the connection type and body type
Expand All @@ -242,15 +229,13 @@ impl<'b, const N: usize> ResponseHeaders<'b, N> {
where
W: Write,
{
let http11 = self.http11.unwrap_or(false);

send_status(http11, self.code, self.reason, &mut output).await?;
send_status(self.http11, self.code, self.reason, &mut output).await?;

self.headers
.send(
Some(request_connection_type),
false,
http11,
self.http11,
chunked_if_unspecified,
output,
)
Expand All @@ -260,42 +245,56 @@ impl<'b, const N: usize> ResponseHeaders<'b, N> {

pub(crate) async fn send_request<W>(
http11: bool,
method: Option<Method>,
path: Option<&str>,
output: W,
method: Method,
path: &str,
mut output: W,
) -> Result<(), Error<W::Error>>
where
W: Write,
{
raw::send_status_line(
true,
http11,
method.map(|method| method.as_str()),
path,
output,
)
.await
// RFC 9112: request-line = method SP request-target SP HTTP-version

output
.write_all(method.as_str().as_bytes())
.await
.map_err(Error::Io)?;
output.write_all(b" ").await.map_err(Error::Io)?;
output.write_all(path.as_bytes()).await.map_err(Error::Io)?;
output.write_all(b" ").await.map_err(Error::Io)?;
raw::send_version(&mut output, http11).await?;
output.write_all(b"\r\n").await.map_err(Error::Io)?;

Ok(())
}

pub(crate) async fn send_status<W>(
http11: bool,
status: Option<u16>,
status: u16,
reason: Option<&str>,
output: W,
mut output: W,
) -> Result<(), Error<W::Error>>
where
W: Write,
{
let status_str: Option<heapless::String<5>> = status.map(|status| status.try_into().unwrap());
// RFC 9112: status-line = HTTP-version SP status-code SP [ reason-phrase ]

raw::send_status_line(
false,
http11,
status_str.as_ref().map(|status| status.as_str()),
reason,
output,
)
.await
raw::send_version(&mut output, http11).await?;
output.write_all(b" ").await.map_err(Error::Io)?;
let status_str: heapless::String<5> = status.try_into().unwrap();
output
.write_all(status_str.as_bytes())
.await
.map_err(Error::Io)?;
output.write_all(b" ").await.map_err(Error::Io)?;
if let Some(reason) = reason {
output
.write_all(reason.as_bytes())
.await
.map_err(Error::Io)?;
}
output.write_all(b"\r\n").await.map_err(Error::Io)?;

Ok(())
}

pub(crate) async fn send_headers<'a, H, W>(
Expand Down Expand Up @@ -1181,61 +1180,6 @@ mod raw {
}
}

pub(crate) async fn send_status_line<W>(
request: bool,
http11: bool,
token: Option<&str>,
extra: Option<&str>,
mut output: W,
) -> Result<(), Error<W::Error>>
where
W: Write,
{
let mut written = false;

if !request {
send_version(&mut output, http11).await?;
written = true;
}

if let Some(token) = token {
if written {
output.write_all(b" ").await.map_err(Error::Io)?;
}

output
.write_all(token.as_bytes())
.await
.map_err(Error::Io)?;

written = true;
}

if written {
output.write_all(b" ").await.map_err(Error::Io)?;
}
if let Some(extra) = extra {
output
.write_all(extra.as_bytes())
.await
.map_err(Error::Io)?;

written = true;
}

if request {
if written {
output.write_all(b" ").await.map_err(Error::Io)?;
}

send_version(&mut output, http11).await?;
}

output.write_all(b"\r\n").await.map_err(Error::Io)?;

Ok(())
}

pub(crate) async fn send_version<W>(mut output: W, http11: bool) -> Result<(), Error<W::Error>>
where
W: Write,
Expand Down
6 changes: 2 additions & 4 deletions edge-http/src/io/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,16 +174,15 @@ where
let mut state = self.unbind();

let result = async {
match send_request(http11, Some(method), Some(uri), state.io.as_mut().unwrap()).await {
match send_request(http11, method, uri, state.io.as_mut().unwrap()).await {
Ok(_) => (),
Err(Error::Io(_)) => {
if !fresh_connection {
// Attempt to reconnect and re-send the request
state.io = None;
state.io = Some(state.socket.connect(state.addr).await.map_err(Error::Io)?);

send_request(http11, Some(method), Some(uri), state.io.as_mut().unwrap())
.await?;
send_request(http11, method, uri, state.io.as_mut().unwrap()).await?;
}
}
Err(other) => Err(other)?,
Expand Down Expand Up @@ -263,7 +262,6 @@ where

let mut state = self.unbind();
let buf_ptr: *mut [u8] = state.buf;

let mut response = ResponseHeaders::new();

match response
Expand Down
17 changes: 6 additions & 11 deletions edge-http/src/io/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ where
message: Option<&str>,
headers: &[(&str, &str)],
) -> Result<(), Error<T::Error>> {
self.complete_request(Some(status), message, headers).await
self.complete_request(status, message, headers).await
}

/// A convenience method to initiate a WebSocket upgrade response
Expand All @@ -125,7 +125,7 @@ where
/// If the connection is still in a request state, and empty 200 OK response is sent
pub async fn complete(&mut self) -> Result<(), Error<T::Error>> {
if self.is_request_initiated() {
self.complete_request(Some(200), Some("OK"), &[]).await?;
self.complete_request(200, Some("OK"), &[]).await?;
}

if self.is_response_initiated() {
Expand All @@ -145,7 +145,7 @@ where
Ok(_) => {
let headers = [("Connection", "Close"), ("Content-Type", "text/plain")];

self.complete_request(Some(500), Some("Internal Error"), &headers)
self.complete_request(500, Some("Internal Error"), &headers)
.await?;

let response = self.response_mut()?;
Expand Down Expand Up @@ -181,7 +181,7 @@ where

async fn complete_request(
&mut self,
status: Option<u16>,
status: u16,
reason: Option<&str>,
headers: &[(&str, &str)],
) -> Result<(), Error<T::Error>> {
Expand All @@ -190,7 +190,7 @@ where
let mut buf = [0; COMPLETION_BUF_SIZE];
while request.io.read(&mut buf).await? > 0 {}

let http11 = request.request.http11.unwrap_or(false);
let http11 = request.request.http11;
let request_connection_type = request.connection_type;

let mut io = self.unbind_mut();
Expand Down Expand Up @@ -918,12 +918,7 @@ mod embedded_svc_compat {
let headers = connection.headers().ok();

if let Some(headers) = headers {
if headers.path.map(|path| self.path == path).unwrap_or(false)
&& headers
.method
.map(|method| self.method == method.into())
.unwrap_or(false)
{
if headers.path == self.path && headers.method == self.method.into() {
return self.handler.handle(connection).await;
}
}
Expand Down
Loading

0 comments on commit 997b76f

Please sign in to comment.