diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 1a69f7f7c4..351b12a030 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -7,12 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased +- **fixed:** Include port number when parsing authority ([#2242]) - **change:** Avoid cloning `Arc` during deserialization of `Path` - **added:** `axum::serve::Serve::tcp_nodelay` and `axum::serve::WithGracefulShutdown::tcp_nodelay` ([#2653]) - **added:** `Router::has_routes` function ([#2790]) - **change:** Update tokio-tungstenite to 0.23 ([#2841]) - **added:** `Serve::local_addr` and `WithGracefulShutdown::local_addr` functions ([#2881]) +[#2242]: https://github.com/tokio-rs/axum/pull/2242 [#2653]: https://github.com/tokio-rs/axum/pull/2653 [#2790]: https://github.com/tokio-rs/axum/pull/2790 [#2841]: https://github.com/tokio-rs/axum/pull/2841 diff --git a/axum/src/extract/host.rs b/axum/src/extract/host.rs index f1d179a545..e5e02b8618 100644 --- a/axum/src/extract/host.rs +++ b/axum/src/extract/host.rs @@ -6,17 +6,21 @@ use async_trait::async_trait; use http::{ header::{HeaderMap, FORWARDED}, request::Parts, + uri::Authority, }; const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host"; -/// Extractor that resolves the hostname of the request. +/// Extractor that resolves the host of the request. /// -/// Hostname is resolved through the following, in order: +/// Host is resolved through the following, in order: /// - `Forwarded` header /// - `X-Forwarded-Host` header /// - `Host` header -/// - request target / URI +/// - Authority of the request URI +/// +/// See for the definition of +/// host. /// /// Note that user agents can set `X-Forwarded-Host` and `Host` headers to arbitrary values so make /// sure to validate them to avoid security issues. @@ -51,8 +55,8 @@ where return Ok(Host(host.to_owned())); } - if let Some(host) = parts.uri.host() { - return Ok(Host(host.to_owned())); + if let Some(authority) = parts.uri.authority() { + return Ok(Host(parse_authority(authority).to_owned())); } Err(HostRejection::FailedToResolveHost(FailedToResolveHost)) @@ -76,11 +80,18 @@ fn parse_forwarded(headers: &HeaderMap) -> Option<&str> { }) } +fn parse_authority(auth: &Authority) -> &str { + auth.as_str() + .rsplit('@') + .next() + .expect("split always has at least 1 item") +} + #[cfg(test)] mod tests { use super::*; use crate::{routing::get, test_helpers::TestClient, Router}; - use http::header::HeaderName; + use http::{header::HeaderName, Request}; fn test_client() -> TestClient { async fn host_as_body(Host(host): Host) -> String { @@ -130,8 +141,14 @@ mod tests { #[crate::test] async fn uri_host() { - let host = test_client().get("/").await.text().await; - assert!(host.contains("127.0.0.1")); + let mut parts = Request::new(()).into_parts().0; + parts.uri = "https://127.0.0.1:1234/image.jpg".parse().unwrap(); + let host = Host::from_request_parts(&mut parts, &()).await.unwrap(); + assert_eq!(host.0, "127.0.0.1:1234"); + + parts.uri = "http://cool:user@[::1]:456/file.txt".parse().unwrap(); + let host = Host::from_request_parts(&mut parts, &()).await.unwrap(); + assert_eq!(host.0, "[::1]:456"); } #[test]