From 2a72400a31b56e5224caf1eaa5113cfb418ec311 Mon Sep 17 00:00:00 2001 From: Lukas Bindreiter Date: Fri, 17 Jan 2025 17:50:23 +0100 Subject: [PATCH] Support multiple proxy servers in Forwarded header parsing (#782) * Support multiple proxy servers in Forwarded header parsing * Update CHANGELOG * use regex and simplify * use integer regex * Use compiled regexes --------- Co-authored-by: vincentsarago --- CHANGES.md | 6 ++- .../api/stac_fastapi/api/middleware.py | 41 +++++++++---------- stac_fastapi/api/tests/test_middleware.py | 28 +++++++++++++ 3 files changed, 53 insertions(+), 22 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 67cee3ea3..4290fa1d5 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,11 +2,15 @@ ## [Unreleased] -## Changed +### Changed * use `string` type instead of python `datetime.datetime` for datetime parameter in `BaseSearchGetRequest`, `ItemCollectionUri` and `BaseCollectionSearchGetRequest` GET models * rename `filter` to `filter_expr` for `FilterExtensionGetRequest` and `FilterExtensionPostRequest` attributes to avoid conflict with python filter method +### Fixed + +* Support multiple proxy servers in the `forwarded` header in `ProxyHeaderMiddleware` ([#782](https://github.com/stac-utils/stac-fastapi/pull/782)) + ## [3.0.5] - 2025-01-10 ### Removed diff --git a/stac_fastapi/api/stac_fastapi/api/middleware.py b/stac_fastapi/api/stac_fastapi/api/middleware.py index 2ba3ef570..0b1192317 100644 --- a/stac_fastapi/api/stac_fastapi/api/middleware.py +++ b/stac_fastapi/api/stac_fastapi/api/middleware.py @@ -1,5 +1,6 @@ """Api middleware.""" +import contextlib import re import typing from http.client import HTTP_PORT, HTTPS_PORT @@ -44,6 +45,10 @@ def __init__( ) +_PROTO_HEADER_REGEX = re.compile(r"proto=(?Phttp(s)?)") +_HOST_HEADER_REGEX = re.compile(r"host=(?P[\w.-]+)(:(?P\d{1,5}))?") + + class ProxyHeaderMiddleware: """Account for forwarding headers when deriving base URL. @@ -68,11 +73,13 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: proto == "https" and port != HTTPS_PORT ): port_suffix = f":{port}" + scope["headers"] = self._replace_header_value_by_name( scope, "host", f"{domain}{port_suffix}", ) + await self.app(scope, receive, send) def _get_forwarded_url_parts(self, scope: Scope) -> Tuple[str]: @@ -87,31 +94,23 @@ def _get_forwarded_url_parts(self, scope: Scope) -> Tuple[str]: else: domain = header_host_parts[0] port = None - forwarded = self._get_header_value_by_name(scope, "forwarded") - if forwarded is not None: - parts = forwarded.split(";") - for part in parts: - if len(part) > 0 and re.search("=", part): - key, value = part.split("=") - if key == "proto": - proto = value - elif key == "host": - host_parts = value.split(":") - domain = host_parts[0] - try: - port = int(host_parts[1]) if len(host_parts) == 2 else None - except ValueError: - # ignore ports that are not valid integers - pass + + if forwarded := self._get_header_value_by_name(scope, "forwarded"): + for proxy in forwarded.split(","): + if (proto_expr := _PROTO_HEADER_REGEX.search(proxy)) and ( + host_expr := _HOST_HEADER_REGEX.search(proxy) + ): + proto = proto_expr.group("proto") + domain = host_expr.group("host") + port_str = host_expr.group("port") # None if not present in the match + else: domain = self._get_header_value_by_name(scope, "x-forwarded-host", domain) proto = self._get_header_value_by_name(scope, "x-forwarded-proto", proto) port_str = self._get_header_value_by_name(scope, "x-forwarded-port", port) - try: - port = int(port_str) if port_str is not None else None - except ValueError: - # ignore ports that are not valid integers - pass + + with contextlib.suppress(ValueError): # ignore ports that are not valid integers + port = int(port_str) if port_str is not None else port return (proto, domain, port) diff --git a/stac_fastapi/api/tests/test_middleware.py b/stac_fastapi/api/tests/test_middleware.py index 00e7f8038..45f65516c 100644 --- a/stac_fastapi/api/tests/test_middleware.py +++ b/stac_fastapi/api/tests/test_middleware.py @@ -155,6 +155,34 @@ def test_replace_header_value_by_name( }, ("https", "test", 1234), ), + ( + { + "scheme": "http", + "server": ["testserver", 80], + "headers": [ + ( + b"forwarded", + # two proxy servers added an entry, we want to use the last one + b"proto=https;host=test:1234,proto=https;host=second-server:1111", + ) + ], + }, + ("https", "second-server", 1111), + ), + ( + { + "scheme": "http", + "server": ["testserver", 80], + "headers": [ + ( + b"forwarded", + # check when host and port are inverted + b"host=test:1234;proto=https", + ) + ], + }, + ("https", "test", 1234), + ), ], ) def test_get_forwarded_url_parts(