diff --git a/Sources/Hummingbird/Middleware/CORSMiddleware.swift b/Sources/Hummingbird/Middleware/CORSMiddleware.swift index 7d58cc9a5..c401e9231 100644 --- a/Sources/Hummingbird/Middleware/CORSMiddleware.swift +++ b/Sources/Hummingbird/Middleware/CORSMiddleware.swift @@ -93,10 +93,12 @@ public struct CORSMiddleware: RouterMiddleware { if request.method == .options { // if request is OPTIONS then return CORS headers and skip the rest of the middleware chain var headers: HTTPFields = [ - .accessControlAllowOrigin: allowOrigin.value(for: request) ?? "", + .accessControlAllowHeaders: self.allowHeaders, + .accessControlAllowMethods: self.allowMethods, ] - headers[.accessControlAllowHeaders] = self.allowHeaders - headers[.accessControlAllowMethods] = self.allowMethods + if let allowOrigin = allowOrigin.value(for: request) { + headers[.accessControlAllowOrigin] = allowOrigin + } if self.allowCredentials { headers[.accessControlAllowCredentials] = "true" } @@ -113,15 +115,32 @@ public struct CORSMiddleware: RouterMiddleware { return Response(status: .noContent, headers: headers, body: .init()) } else { // if not OPTIONS then run rest of middleware chain and add origin value at the end - var response = try await next(request, context) - response.headers[.accessControlAllowOrigin] = self.allowOrigin.value(for: request) ?? "" - if self.allowCredentials { - response.headers[.accessControlAllowCredentials] = "true" - } - if case .originBased = self.allowOrigin { - response.headers[.vary] = "Origin" + do { + var response = try await next(request, context) + response.headers[.accessControlAllowOrigin] = self.allowOrigin.value(for: request) + if self.allowCredentials { + response.headers[.accessControlAllowCredentials] = "true" + } + if case .originBased = self.allowOrigin { + response.headers[.vary] = "Origin" + } + return response + } catch { + // If next throws an error add headers to error + var additionalHeaders = HTTPFields() + additionalHeaders[.accessControlAllowOrigin] = self.allowOrigin.value(for: request) + if self.allowCredentials { + additionalHeaders[.accessControlAllowCredentials] = "true" + } + if case .originBased = self.allowOrigin { + additionalHeaders[.vary] = "Origin" + } + throw EditedHTTPError( + originalError: error, + additionalHeaders: additionalHeaders, + context: context + ) } - return response } } } diff --git a/Sources/Hummingbird/Server/EditedHTTPError.swift b/Sources/Hummingbird/Server/EditedHTTPError.swift new file mode 100644 index 000000000..21b239714 --- /dev/null +++ b/Sources/Hummingbird/Server/EditedHTTPError.swift @@ -0,0 +1,39 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import HTTPTypes +import HummingbirdCore + +/// Error generated from another error that adds additional headers to the response +struct EditedHTTPError: HTTPResponseError { + let status: HTTPResponse.Status + let headers: HTTPFields + let body: ByteBuffer? + + init(originalError: Error, additionalHeaders: HTTPFields, context: some BaseRequestContext) { + if let httpError = originalError as? HTTPResponseError { + self.status = httpError.status + self.headers = httpError.headers + additionalHeaders + self.body = httpError.body(allocator: context.allocator) + } else { + self.status = .internalServerError + self.headers = additionalHeaders + self.body = nil + } + } + + func body(allocator: NIOCore.ByteBufferAllocator) -> NIOCore.ByteBuffer? { + return self.body + } +} diff --git a/Tests/HummingbirdTests/MiddlewareTests.swift b/Tests/HummingbirdTests/MiddlewareTests.swift index a9d961b21..82ca8b134 100644 --- a/Tests/HummingbirdTests/MiddlewareTests.swift +++ b/Tests/HummingbirdTests/MiddlewareTests.swift @@ -242,6 +242,18 @@ final class MiddlewareTests: XCTestCase { } } + func testCORSHeadersAndErrors() async throws { + let router = Router() + router.middlewares.add(CORSMiddleware()) + let app = Application(responder: router.buildResponder()) + try await app.test(.router) { client in + try await client.execute(uri: "/hello", method: .get, headers: [.origin: "foo.com"]) { response in + // headers come back in opposite order as middleware is applied to responses in that order + XCTAssertEqual(response.headers[.accessControlAllowOrigin], "foo.com") + } + } + } + func testLogRequestMiddleware() async throws { let logAccumalator = TestLogHandler.LogAccumalator() let router = Router()