Skip to content

Commit

Permalink
Ensure CORS middleware sets headers when an error is thrown
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-fowler committed May 22, 2024
1 parent 347b8db commit 0288162
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 11 deletions.
41 changes: 30 additions & 11 deletions Sources/Hummingbird/Middleware/CORSMiddleware.swift
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,12 @@ public struct CORSMiddleware<Context: BaseRequestContext>: RouterMiddleware {
if request.method == .options {
// if request is OPTIONS then return CORS headers and skip the rest of the middleware chain
var headers: HTTPFields = [
.accessControlAllowOrigin: allowOrigin.value(for: request) ?? "",
.accessControlAllowHeaders: self.allowHeaders,
.accessControlAllowMethods: self.allowMethods,
]
headers[.accessControlAllowHeaders] = self.allowHeaders
headers[.accessControlAllowMethods] = self.allowMethods
if let allowOrigin = allowOrigin.value(for: request) {
headers[.accessControlAllowOrigin] = allowOrigin
}
if self.allowCredentials {
headers[.accessControlAllowCredentials] = "true"
}
Expand All @@ -113,15 +115,32 @@ public struct CORSMiddleware<Context: BaseRequestContext>: RouterMiddleware {
return Response(status: .noContent, headers: headers, body: .init())
} else {
// if not OPTIONS then run rest of middleware chain and add origin value at the end
var response = try await next(request, context)
response.headers[.accessControlAllowOrigin] = self.allowOrigin.value(for: request) ?? ""
if self.allowCredentials {
response.headers[.accessControlAllowCredentials] = "true"
}
if case .originBased = self.allowOrigin {
response.headers[.vary] = "Origin"
do {
var response = try await next(request, context)
response.headers[.accessControlAllowOrigin] = self.allowOrigin.value(for: request)
if self.allowCredentials {
response.headers[.accessControlAllowCredentials] = "true"
}
if case .originBased = self.allowOrigin {
response.headers[.vary] = "Origin"
}
return response
} catch {
// If next throws an error add headers to error
var additionalHeaders = HTTPFields()
additionalHeaders[.accessControlAllowOrigin] = self.allowOrigin.value(for: request)
if self.allowCredentials {
additionalHeaders[.accessControlAllowCredentials] = "true"
}
if case .originBased = self.allowOrigin {
additionalHeaders[.vary] = "Origin"
}
throw EditedHTTPError(
originalError: error,
additionalHeaders: additionalHeaders,
context: context
)
}
return response
}
}
}
39 changes: 39 additions & 0 deletions Sources/Hummingbird/Server/EditedHTTPError.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import HTTPTypes
import HummingbirdCore

/// Error generated from another error that adds additional headers to the response
struct EditedHTTPError: HTTPResponseError {
let status: HTTPResponse.Status
let headers: HTTPFields
let body: ByteBuffer?

init(originalError: Error, additionalHeaders: HTTPFields, context: some BaseRequestContext) {
if let httpError = originalError as? HTTPResponseError {
self.status = httpError.status
self.headers = httpError.headers + additionalHeaders
self.body = httpError.body(allocator: context.allocator)
} else {
self.status = .internalServerError
self.headers = additionalHeaders
self.body = nil
}
}

func body(allocator: NIOCore.ByteBufferAllocator) -> NIOCore.ByteBuffer? {
return self.body
}
}
12 changes: 12 additions & 0 deletions Tests/HummingbirdTests/MiddlewareTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,18 @@ final class MiddlewareTests: XCTestCase {
}
}

func testCORSHeadersAndErrors() async throws {
let router = Router()
router.middlewares.add(CORSMiddleware())
let app = Application(responder: router.buildResponder())
try await app.test(.router) { client in
try await client.execute(uri: "/hello", method: .get, headers: [.origin: "foo.com"]) { response in
// headers come back in opposite order as middleware is applied to responses in that order
XCTAssertEqual(response.headers[.accessControlAllowOrigin], "foo.com")
}
}
}

func testLogRequestMiddleware() async throws {
let logAccumalator = TestLogHandler.LogAccumalator()
let router = Router()
Expand Down

0 comments on commit 0288162

Please sign in to comment.