diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2d5346e..4901ae4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,14 +9,14 @@ on: jobs: xenial: container: - image: swift:5.2-xenial + image: swift:5.5-xenial runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - run: swift test --enable-test-discovery --enable-code-coverage bionic: container: - image: swift:5.4-bionic + image: swift:5.5-bionic runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 diff --git a/Package.swift b/Package.swift index 0dec321..bf75b0e 100644 --- a/Package.swift +++ b/Package.swift @@ -1,10 +1,10 @@ -// swift-tools-version:5.2 +// swift-tools-version:5.5 import PackageDescription let package = Package( name: "LeafErrorMiddleware", platforms: [ - .macOS(.v10_15), + .macOS(.v12), ], products: [ .library(name: "LeafErrorMiddleware", targets: ["LeafErrorMiddleware"]), diff --git a/Sources/LeafErrorMiddleware/LeafErrorMiddleware.swift b/Sources/LeafErrorMiddleware/LeafErrorMiddleware.swift index 33c237e..a9c03f9 100644 --- a/Sources/LeafErrorMiddleware/LeafErrorMiddleware.swift +++ b/Sources/LeafErrorMiddleware/LeafErrorMiddleware.swift @@ -1,71 +1,70 @@ import Vapor /// Captures all errors and transforms them into an internal server error. -public final class LeafErrorMiddleware: Middleware { +public final class LeafErrorMiddleware: AsyncMiddleware { + let contextGenerator: (HTTPStatus, Error, Request) async throws -> T - let contextGenerator: ((HTTPStatus, Error, Request) -> EventLoopFuture) - - public init(contextGenerator: @escaping ((HTTPStatus, Error, Request) -> EventLoopFuture)) { + public init(contextGenerator: @escaping ((HTTPStatus, Error, Request) async throws -> T)) { self.contextGenerator = contextGenerator } - + /// See `Middleware.respond` - public func respond(to request: Request, chainingTo next: Responder) -> EventLoopFuture { - return next.respond(to: request).flatMap { res in + public func respond(to request: Request, chainingTo next: AsyncResponder) async throws -> Response { + do { + let res = try await next.respond(to: request) if res.status.code >= HTTPResponseStatus.badRequest.code { - return self.handleError(for: request, status: res.status, error: Abort(res.status)) + return try await handleError(for: request, status: res.status, error: Abort(res.status)) } else { - return res.encodeResponse(for: request) + return try await res.encodeResponse(for: request) } - }.flatMapError { error in + } catch { request.logger.report(error: error) - switch (error) { - case let abort as AbortError: - guard - abort.status.representsError + switch error { + case let abort as AbortError: + guard + abort.status.representsError else { if let location = abort.headers[.location].first { - return request.eventLoop.future(request.redirect(to: location)) + return request.redirect(to: location) } else { - return self.handleError(for: request, status: abort.status, error: error) + return try await handleError(for: request, status: abort.status, error: error) } - } - return self.handleError(for: request, status: abort.status, error: error) - default: - return self.handleError(for: request, status: .internalServerError, error: error) + } + return try await handleError(for: request, status: abort.status, error: error) + default: + return try await handleError(for: request, status: .internalServerError, error: error) } } } - - private func handleError(for req: Request, status: HTTPStatus, error: Error) -> EventLoopFuture { + + private func handleError(for request: Request, status: HTTPStatus, error: Error) async throws -> Response { if status == .notFound { - return contextGenerator(status, error, req).flatMap { context in - return req.view.render("404", context).encodeResponse(for: req).map { res in - res.status = status - return res - } - }.flatMapError { newError in - return self.renderServerErrorPage(for: status, request: req, error: newError) + do { + let context = try await contextGenerator(status, error, request) + let res = try await request.view.render("404", context).encodeResponse(for: request).get() + res.status = status + return res + } catch { + return try await renderServerErrorPage(for: status, request: request, error: error) } } - return renderServerErrorPage(for: status, request: req, error: error) + return try await renderServerErrorPage(for: status, request: request, error: error) } - - private func renderServerErrorPage(for status: HTTPStatus, request: Request, error: Error) -> EventLoopFuture { - return contextGenerator(status, error, request).flatMap { context in - request.logger.error("Internal server error. Status: \(status.code) - path: \(request.url)") - return request.view.render("serverError", context).encodeResponse(for: request).map { res in - res.status = status - return res - } - }.flatMapError { error -> EventLoopFuture in + + private func renderServerErrorPage(for status: HTTPStatus, request: Request, error: Error) async throws -> Response { + do { + let context = try await contextGenerator(status, error, request) + request.logger.error("Internal server error. Status: \(status.code) - path: \(request.url)") + let res = try await request.view.render("serverError", context).encodeResponse(for: request).get() + res.status = status + return res + } catch { let body = "

Internal Error

There was an internal error. Please try again later.

" request.logger.error("Failed to render custom error page - \(error)") - return body.encodeResponse(for: request).map { res in - res.status = status - res.headers.replaceOrAdd(name: .contentType, value: "text/html; charset=utf-8") - return res - } + let res = try await body.encodeResponse(for: request) + res.status = status + res.headers.replaceOrAdd(name: .contentType, value: "text/html; charset=utf-8") + return res } } } diff --git a/Sources/LeafErrorMiddleware/LeafErrorMiddlewareDefaultGenerator.swift b/Sources/LeafErrorMiddleware/LeafErrorMiddlewareDefaultGenerator.swift index 688c843..8fb5085 100644 --- a/Sources/LeafErrorMiddleware/LeafErrorMiddlewareDefaultGenerator.swift +++ b/Sources/LeafErrorMiddleware/LeafErrorMiddlewareDefaultGenerator.swift @@ -1,25 +1,7 @@ import Vapor -@available(*, deprecated, renamed: "LeafErrorMiddlewareDefaultGenerator") -public enum LeafErorrMiddlewareDefaultGenerator { - static func generate(_ status: HTTPStatus, _ error: Error, _ req: Request) -> EventLoopFuture { - let reason: String? - if let abortError = error as? AbortError { - reason = abortError.reason - } else { - reason = nil - } - let context = DefaultContext(status: status.code.description, statusMessage: status.reasonPhrase, reason: reason) - return req.eventLoop.future(context ) - } - - public static func build() -> LeafErrorMiddleware { - LeafErrorMiddleware(contextGenerator: generate) - } -} - public enum LeafErrorMiddlewareDefaultGenerator { - static func generate(_ status: HTTPStatus, _ error: Error, _ req: Request) -> EventLoopFuture { + static func generate(_ status: HTTPStatus, _ error: Error, _ req: Request) async throws -> DefaultContext { let reason: String? if let abortError = error as? AbortError { reason = abortError.reason @@ -27,7 +9,7 @@ public enum LeafErrorMiddlewareDefaultGenerator { reason = nil } let context = DefaultContext(status: status.code.description, statusMessage: status.reasonPhrase, reason: reason) - return req.eventLoop.future(context ) + return context } public static func build() -> LeafErrorMiddleware { diff --git a/Tests/LeafErrorMiddlewareTests/CustomGeneratorTests.swift b/Tests/LeafErrorMiddlewareTests/CustomGeneratorTests.swift index 17fb6c5..5307e5f 100644 --- a/Tests/LeafErrorMiddlewareTests/CustomGeneratorTests.swift +++ b/Tests/LeafErrorMiddlewareTests/CustomGeneratorTests.swift @@ -1,84 +1,76 @@ -import XCTest import LeafErrorMiddleware -import Vapor @testable import Logging +import Vapor +import XCTest struct AContext: Encodable { let trigger: Bool } class CustomGeneratorTests: XCTestCase { - // MARK: - Properties + var app: Application! var viewRenderer: ThrowingViewRenderer! var logger = CapturingLogger() var eventLoopGroup: EventLoopGroup! // MARK: - Overrides + override func setUpWithError() throws { eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) viewRenderer = ThrowingViewRenderer(eventLoop: eventLoopGroup.next()) LoggingSystem.bootstrapInternal { _ in - return self.logger + self.logger } app = Application(.testing, .shared(eventLoopGroup)) app.views.use { _ in - return self.viewRenderer + self.viewRenderer } func routes(_ router: RoutesBuilder) throws { + router.get("ok") { _ in + "ok" + } - router.get("ok") { req in - return "ok" + router.get("404") { _ -> HTTPStatus in + .notFound } - router.get("404") { req -> HTTPStatus in - return .notFound + router.get("403") { _ -> Response in + throw Abort(.forbidden) } - router.get("serverError") { req -> EventLoopFuture in + router.get("serverError") { _ -> Response in throw Abort(.internalServerError) } - router.get("unknownError") { req -> EventLoopFuture in + router.get("unknownError") { _ -> Response in throw TestError() } - router.get("unauthorized") { req -> EventLoopFuture in + router.get("unauthorized") { _ -> Response in throw Abort(.unauthorized) } - router.get("future404") { req -> EventLoopFuture in - return req.eventLoop.future(error: Abort(.notFound)) - } - - router.get("future403") { req -> EventLoopFuture in - return req.eventLoop.future(error: Abort(.forbidden)) - } - - router.get("future303") { req -> EventLoopFuture in - return req.eventLoop.future(error: Abort.redirect(to: "ok")) - } - - router.get("future404NoAbort") { req -> EventLoopFuture in - return req.eventLoop.future(.notFound) + router.get("303") { _ -> Response in + throw Abort.redirect(to: "ok") } - router.get("404withReason") { req -> HTTPStatus in + router.get("404withReason") { _ -> HTTPStatus in throw Abort(.notFound, reason: "Could not find it") } - router.get("500withReason") { req -> HTTPStatus in + router.get("500withReason") { _ -> HTTPStatus in throw Abort(.badGateway, reason: "I messed up") } } try routes(app) - let leafMiddleware = LeafErrorMiddleware() { status, error, req -> EventLoopFuture in - return req.eventLoop.future(AContext(trigger: true)) + let leafMiddleware = LeafErrorMiddleware { status, error, req async throws -> AContext in + AContext(trigger: true) } app.middleware.use(leafMiddleware) } @@ -139,36 +131,24 @@ class CustomGeneratorTests: XCTestCase { XCTAssertEqual(viewRenderer.leafPath, "serverError") } - func testNonAbort404IsCaughtCorrectly() throws { - let response = try app.getResponse(to: "/404") - XCTAssertEqual(response.status, .notFound) - XCTAssertEqual(viewRenderer.leafPath, "404") - } - - func testThatFuture404IsCaughtCorrectly() throws { - let response = try app.getResponse(to: "/future404") - XCTAssertEqual(response.status, .notFound) - XCTAssertEqual(viewRenderer.leafPath, "404") + func testThatRedirectIsNotCaught() throws { + let response = try app.getResponse(to: "/303") + XCTAssertEqual(response.status, .seeOther) + XCTAssertEqual(response.headers[.location].first, "ok") } - func testFutureNonAbort404IsCaughtCorrectly() throws { - let response = try app.getResponse(to: "/future404NoAbort") + func testNonAbort404IsCaughtCorrectly() throws { + let response = try app.getResponse(to: "/404") XCTAssertEqual(response.status, .notFound) XCTAssertEqual(viewRenderer.leafPath, "404") } - func testThatFuture403IsCaughtCorrectly() throws { - let response = try app.getResponse(to: "/future403") + func testThat403IsCaughtCorrectly() throws { + let response = try app.getResponse(to: "/403") XCTAssertEqual(response.status, .forbidden) XCTAssertEqual(viewRenderer.leafPath, "serverError") } - func testThatRedirectIsNotCaught() throws { - let response = try app.getResponse(to: "/future303") - XCTAssertEqual(response.status, .seeOther) - XCTAssertEqual(response.headers[.location].first, "ok") - } - func testContextGeneratedOn404Page() throws { let response = try app.getResponse(to: "/404") XCTAssertEqual(response.status, .notFound) @@ -189,20 +169,19 @@ class CustomGeneratorTests: XCTestCase { app.shutdown() app = Application(.testing, .shared(eventLoopGroup)) app.views.use { _ in - return self.viewRenderer + self.viewRenderer } - let leafErrorMiddleware = LeafErrorMiddleware() { status, error, req -> EventLoopFuture in - return req.eventLoop.future(error: Abort(.internalServerError)) - + let leafErrorMiddleware = LeafErrorMiddleware { _, _, _ -> AContext in + throw Abort(.internalServerError) } app.middleware = .init() app.middleware.use(leafErrorMiddleware) - app.get("404") { req -> EventLoopFuture in - req.eventLoop.makeFailedFuture(Abort(.notFound)) + app.get("404") { _ async throws -> Response in + throw Abort(.notFound) } - app.get("500") { req -> EventLoopFuture in - req.eventLoop.makeFailedFuture(Abort(.internalServerError)) + app.get("500") { _ async throws -> Response in + throw Abort(.internalServerError) } let response404 = try app.getResponse(to: "404") @@ -213,6 +192,4 @@ class CustomGeneratorTests: XCTestCase { XCTAssertEqual(response500.status, .internalServerError) XCTAssertNil(viewRenderer.leafPath) } - - } diff --git a/Tests/LeafErrorMiddlewareTests/DefaultLeafErrorMiddlewareTests.swift b/Tests/LeafErrorMiddlewareTests/DefaultLeafErrorMiddlewareTests.swift index 658f5d9..1766b18 100644 --- a/Tests/LeafErrorMiddlewareTests/DefaultLeafErrorMiddlewareTests.swift +++ b/Tests/LeafErrorMiddlewareTests/DefaultLeafErrorMiddlewareTests.swift @@ -1,118 +1,116 @@ -import XCTest import LeafErrorMiddleware -import Vapor @testable import Logging +import Vapor +import XCTest class DefaultLeafErrorMiddlewareTests: XCTestCase { - // MARK: - Properties + var app: Application! var viewRenderer: ThrowingViewRenderer! var logger = CapturingLogger() var eventLoopGroup: EventLoopGroup! - + // MARK: - Overrides + override func setUpWithError() throws { eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) viewRenderer = ThrowingViewRenderer(eventLoop: eventLoopGroup.next()) LoggingSystem.bootstrapInternal { _ in - return self.logger + self.logger } app = Application(.testing, .shared(eventLoopGroup)) app.views.use { _ in - return self.viewRenderer + self.viewRenderer } func routes(_ router: RoutesBuilder) throws { + router.get("ok") { _ in + "ok" + } - router.get("ok") { req in - return "ok" + router.get("404") { _ -> HTTPStatus in + throw Abort(.notFound) } - - router.get("404") { req -> HTTPStatus in - return .notFound + + router.get("403") { _ -> Response in + throw Abort(.forbidden) } - router.get("serverError") { req -> EventLoopFuture in + router.get("serverError") { _ -> Response in throw Abort(.internalServerError) } - router.get("unknownError") { req -> EventLoopFuture in + router.get("unknownError") { _ -> Response in throw TestError() } - router.get("unauthorized") { req -> EventLoopFuture in + router.get("unauthorized") { _ -> Response in throw Abort(.unauthorized) } - - router.get("future404") { req -> EventLoopFuture in - return req.eventLoop.future(error: Abort(.notFound)) - } - - router.get("future403") { req -> EventLoopFuture in - return req.eventLoop.future(error: Abort(.forbidden)) - } - router.get("future303") { req -> EventLoopFuture in - return req.eventLoop.future(error: Abort.redirect(to: "ok")) + router.get("303") { _ -> Response in + throw Abort.redirect(to: "ok") } - - router.get("future404NoAbort") { req -> EventLoopFuture in - return req.eventLoop.future(.notFound) - } - - router.get("404withReason") { req -> HTTPStatus in + + router.get("404withReason") { _ -> HTTPStatus in throw Abort(.notFound, reason: "Could not find it") } - - router.get("500withReason") { req -> HTTPStatus in + + router.get("500withReason") { _ -> HTTPStatus in throw Abort(.badGateway, reason: "I messed up") } } try routes(app) - + app.middleware.use(LeafErrorMiddlewareDefaultGenerator.build()) } - + override func tearDownWithError() throws { app.shutdown() try eventLoopGroup.syncShutdownGracefully() } - + // MARK: - Tests - + + func testThatRedirectIsNotCaught() throws { + let response = try app.getResponse(to: "/303") + XCTAssertEqual(response.status, .seeOther) + XCTAssertEqual(response.headers[.location].first, "ok") + } + func testThatValidEndpointWorks() throws { let response = try app.getResponse(to: "/ok") XCTAssertEqual(response.status, .ok) } - + func testThatRequestingInvalidEndpointReturns404View() throws { let response = try app.getResponse(to: "/unknown") XCTAssertEqual(response.status, .notFound) XCTAssertEqual(viewRenderer.leafPath, "404") } - + func testThatRequestingPageThatCausesAServerErrorReturnsServerErrorView() throws { let response = try app.getResponse(to: "/serverError") XCTAssertEqual(response.status, .internalServerError) XCTAssertEqual(viewRenderer.leafPath, "serverError") } - + func testThatErrorGetsLogged() throws { _ = try app.getResponse(to: "/serverError") XCTAssertNotNil(logger.message) XCTAssertEqual(logger.logLevelUsed, .error) } - + func testThatMiddlewareFallsBackIfViewRendererFails() throws { viewRenderer.shouldThrow = true let response = try app.getResponse(to: "/serverError") XCTAssertEqual(response.status, .internalServerError) XCTAssertEqual(response.body.string, "

Internal Error

There was an internal error. Please try again later.

") } - + func testThatMiddlewareFallsBackIfViewRendererFailsFor404() throws { viewRenderer.shouldThrow = true let response = try app.getResponse(to: "/unknown") @@ -125,13 +123,13 @@ class DefaultLeafErrorMiddlewareTests: XCTestCase { _ = try app.getResponse(to: "/serverError") XCTAssertTrue(logger.message?.starts(with: "Failed to render custom error page") ?? false) } - + func testThatRandomErrorGetsReturnedAsServerError() throws { let response = try app.getResponse(to: "/unknownError") XCTAssertEqual(response.status, .internalServerError) XCTAssertEqual(viewRenderer.leafPath, "serverError") } - + func testThatUnauthorisedIsPassedThroughToServerErrorPage() throws { let response = try app.getResponse(to: "/unauthorized") XCTAssertEqual(response.status, .unauthorized) @@ -143,49 +141,31 @@ class DefaultLeafErrorMiddlewareTests: XCTestCase { XCTAssertEqual(context.status, "401") XCTAssertEqual(context.statusMessage, "Unauthorized") } - + func testNonAbort404IsCaughtCorrectly() throws { let response = try app.getResponse(to: "/404") XCTAssertEqual(response.status, .notFound) XCTAssertEqual(viewRenderer.leafPath, "404") } - - func testThatFuture404IsCaughtCorrectly() throws { - let response = try app.getResponse(to: "/future404") - XCTAssertEqual(response.status, .notFound) - XCTAssertEqual(viewRenderer.leafPath, "404") - } - - func testFutureNonAbort404IsCaughtCorrectly() throws { - let response = try app.getResponse(to: "/future404NoAbort") - XCTAssertEqual(response.status, .notFound) - XCTAssertEqual(viewRenderer.leafPath, "404") - } - - func testThatFuture403IsCaughtCorrectly() throws { - let response = try app.getResponse(to: "/future403") + + func testThat403IsCaughtCorrectly() throws { + let response = try app.getResponse(to: "/403") XCTAssertEqual(response.status, .forbidden) XCTAssertEqual(viewRenderer.leafPath, "serverError") } - func testThatRedirectIsNotCaught() throws { - let response = try app.getResponse(to: "/future303") - XCTAssertEqual(response.status, .seeOther) - XCTAssertEqual(response.headers[.location].first, "ok") - } - func testAddingMiddlewareToRouteGroup() throws { app.shutdown() app = Application(.testing, .shared(eventLoopGroup)) app.views.use { _ in - return self.viewRenderer + self.viewRenderer } let middlewareGroup = app.grouped(LeafErrorMiddlewareDefaultGenerator.build()) - middlewareGroup.get("404") { req -> EventLoopFuture in - req.eventLoop.makeFailedFuture(Abort(.notFound)) + middlewareGroup.get("404") { _ async throws -> Response in + throw Abort(.notFound) } - middlewareGroup.get("ok") { req in - return "OK" + middlewareGroup.get("ok") { _ in + "OK" } let validResponse = try app.getResponse(to: "ok") XCTAssertEqual(validResponse.status, .ok) @@ -193,7 +173,7 @@ class DefaultLeafErrorMiddlewareTests: XCTestCase { XCTAssertEqual(response.status, .notFound) XCTAssertEqual(viewRenderer.leafPath, "404") } - + func testReasonIsPassedThroughTo404Page() throws { let response = try app.getResponse(to: "/404withReason") XCTAssertEqual(response.status, .notFound) @@ -204,7 +184,7 @@ class DefaultLeafErrorMiddlewareTests: XCTestCase { } XCTAssertEqual(context.reason, "Could not find it") } - + func testReasonIsPassedThroughTo500Page() throws { let response = try app.getResponse(to: "/500withReason") XCTAssertEqual(response.status, .badGateway) @@ -219,7 +199,7 @@ class DefaultLeafErrorMiddlewareTests: XCTestCase { extension Application { func getResponse(to path: String) throws -> Response { - let request = Request(application: self, method: .GET, url: URI(path: path), on: self.eventLoopGroup.next()) - return try self.responder.respond(to: request).wait() + let request = Request(application: self, method: .GET, url: URI(path: path), on: eventLoopGroup.next()) + return try responder.respond(to: request).wait() } } diff --git a/Tests/LeafErrorMiddlewareTests/Fakes/CapturingLogger.swift b/Tests/LeafErrorMiddlewareTests/Fakes/CapturingLogger.swift index f98722c..297b501 100644 --- a/Tests/LeafErrorMiddlewareTests/Fakes/CapturingLogger.swift +++ b/Tests/LeafErrorMiddlewareTests/Fakes/CapturingLogger.swift @@ -1,7 +1,6 @@ import Vapor class CapturingLogger: LogHandler { - subscript(metadataKey key: String) -> Logger.Metadata.Value? { get { return self.metadata[key] } set { self.metadata[key] = newValue } diff --git a/Tests/LeafErrorMiddlewareTests/Fakes/ThrowingViewRenderer.swift b/Tests/LeafErrorMiddlewareTests/Fakes/ThrowingViewRenderer.swift index 4d68ee5..2e5628e 100644 --- a/Tests/LeafErrorMiddlewareTests/Fakes/ThrowingViewRenderer.swift +++ b/Tests/LeafErrorMiddlewareTests/Fakes/ThrowingViewRenderer.swift @@ -1,7 +1,6 @@ import Vapor class ThrowingViewRenderer: ViewRenderer { - var shouldCache = false var eventLoop: EventLoop var shouldThrow = false @@ -10,12 +9,12 @@ class ThrowingViewRenderer: ViewRenderer { self.eventLoop = eventLoop } - private(set) var capturedContext: Encodable? = nil - private(set) var leafPath: String? = nil - func render(_ name: String, _ context: E) -> EventLoopFuture where E : Encodable { + private(set) var capturedContext: Encodable? + private(set) var leafPath: String? + func render(_ name: String, _ context: E) -> EventLoopFuture where E: Encodable { self.capturedContext = context self.leafPath = name - if shouldThrow { + if self.shouldThrow { return self.eventLoop.makeFailedFuture(TestError()) } let response = "Test" @@ -23,7 +22,7 @@ class ThrowingViewRenderer: ViewRenderer { byteBuffer.writeString(response) return self.eventLoop.future(View(data: byteBuffer)) } - + func `for`(_ request: Request) -> ViewRenderer { return self }