diff --git a/Sources/Hummingbird/Middleware/MiddlewareModule/MiddlewareStack.swift b/Sources/Hummingbird/Middleware/MiddlewareModule/MiddlewareStack.swift index 2d07f3cd..fa333738 100644 --- a/Sources/Hummingbird/Middleware/MiddlewareModule/MiddlewareStack.swift +++ b/Sources/Hummingbird/Middleware/MiddlewareModule/MiddlewareStack.swift @@ -12,6 +12,8 @@ // //===----------------------------------------------------------------------===// +// MARK: - Middleware2 + public struct _Middleware2: MiddlewareProtocol where M0.Input == M1.Input, M0.Context == M1.Context, M0.Output == M1.Output { public typealias Input = M0.Input public typealias Output = M0.Output @@ -34,6 +36,10 @@ public struct _Middleware2: Midd } } +extension _Middleware2: RouterMiddleware where M0.Input == Request, M0.Output == Response {} + +// MARK: - MiddlewareFixedTypeBuilder + /// Middleware stack result builder /// /// Generates a middleware stack from the elements inside the result builder. The input, diff --git a/Sources/HummingbirdRouter/RouterController.swift b/Sources/HummingbirdRouter/RouterController.swift new file mode 100644 index 00000000..259ae41f --- /dev/null +++ b/Sources/HummingbirdRouter/RouterController.swift @@ -0,0 +1,47 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2023-2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Hummingbird + +// MARK: - RouterController + +/// A type that represents part of your app's middleware and routes +/// +/// You create custom controllers by declaring types that conform to the `RouterController` +/// protocol. Implement the required ``RouterController/body-swift.property`` computed +/// property to provide the content for your custom controller. +/// +/// struct MyController: RouterController { +/// typealias Context = BasicRouterRequestContext +/// +/// var body: some RouterMiddleware { +/// Get("foo") { _,_ in "foo" } +/// } +/// } +/// +/// Assemble the controller's body by combining one or more of the built-in controllers or middleware. +/// provided by Hummingbird, plus other custom controllers that you define, into a hierarchy of controllers. +public protocol RouterController { + associatedtype Context + associatedtype Body: RouterMiddleware + @MiddlewareFixedTypeBuilder var body: Body { get } +} + +// MARK: MiddlewareFixedTypeBuilder + RouterController Builders + +extension MiddlewareFixedTypeBuilder { + public static func buildExpression(_ c0: C0) -> C0.Body where C0.Body.Input == Input, C0.Body.Output == Output, C0.Body.Context == Context { + return c0.body + } +} diff --git a/Tests/HummingbirdRouterTests/ControllerTests.swift b/Tests/HummingbirdRouterTests/ControllerTests.swift new file mode 100644 index 00000000..dd3b326e --- /dev/null +++ b/Tests/HummingbirdRouterTests/ControllerTests.swift @@ -0,0 +1,152 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2021-2023 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Hummingbird +import HummingbirdRouter +import HummingbirdTesting +import XCTest + +final class ControllerTests: XCTestCase { + func testRouterControllerWithSingleRoute() async throws { + struct TestController: RouterController { + typealias Context = BasicRouterRequestContext + var body: some RouterMiddleware { + Get("foo") { _, _ in "foo" } + } + } + + let router = RouterBuilder(context: BasicRouterRequestContext.self) { + TestController() + } + + let app = Application(responder: router) + try await app.test(.router) { client in + try await client.execute(uri: "/foo", method: .get) { + XCTAssertEqual(String(buffer: $0.body), "foo") + } + } + } + + func testRouterControllerWithMultipleRoutes() async throws { + struct TestController: RouterController { + typealias Context = BasicRouterRequestContext + var body: some RouterMiddleware { + Get("foo") { _, _ in "foo" } + Get("bar") { _, _ in "bar" } + } + } + + let router = RouterBuilder(context: BasicRouterRequestContext.self) { + TestController() + } + + let app = Application(responder: router) + try await app.test(.router) { client in + try await client.execute(uri: "/foo", method: .get) { + XCTAssertEqual(String(buffer: $0.body), "foo") + } + + try await client.execute(uri: "/bar", method: .get) { + XCTAssertEqual(String(buffer: $0.body), "bar") + } + } + } + + func testRouterControllerWithGenericChildren() async throws { + struct ChildController: RouterController { + typealias Context = BasicRouterRequestContext + let name: String + var body: some RouterMiddleware { + Get("child_\(self.name)") { _, _ in "child_\(self.name)" } + } + } + + struct ParentController: RouterController where Child.Context == Context { + var child: Child + + init(@MiddlewareFixedTypeBuilder _ child: () -> Child) { + self.child = child() + } + + var body: some RouterMiddleware { + RouteGroup("parent") { + self.child + } + } + } + + let router = RouterBuilder(context: BasicRouterRequestContext.self) { + ParentController { + Get("child_a") { _, _ in "child_a" } + ChildController(name: "b") + ChildController(name: "c") + Get("child_d") { _, _ in "child_d" } + } + } + + let app = Application(responder: router) + try await app.test(.router) { client in + for letter in "abcd" { + try await client.execute(uri: "/parent/child_\(letter)", method: .get) { + XCTAssertEqual(String(buffer: $0.body), "child_\(letter)") + } + } + } + } + + func testRouterControllerWithMiddleware() async throws { + struct TestMiddleware: RouterMiddleware { + func handle(_ request: Request, context: Context, next: (Request, Context) async throws -> Response) async throws -> Response { + var response = try await next(request, context) + response.headers[.middleware] = "TestMiddleware" + return response + } + } + + struct ChildController: RouterController { + typealias Context = BasicRouterRequestContext + var body: some RouterMiddleware { + Get("foo") { _, _ in "foo" } + } + } + + struct ParentController: RouterController { + typealias Context = BasicRouterRequestContext + var body: some RouterMiddleware { + RouteGroup("parent") { + TestMiddleware() + ChildController() + Get("bar") { _, _ in "bar" } + } + } + } + + let router = RouterBuilder(context: BasicRouterRequestContext.self) { + ParentController() + } + + let app = Application(responder: router) + try await app.test(.router) { client in + try await client.execute(uri: "/parent/foo", method: .get) { + XCTAssertEqual($0.headers[.middleware], "TestMiddleware") + XCTAssertEqual(String(buffer: $0.body), "foo") + } + + try await client.execute(uri: "/parent/bar", method: .get) { + XCTAssertEqual($0.headers[.middleware], "TestMiddleware") + XCTAssertEqual(String(buffer: $0.body), "bar") + } + } + } +}