diff --git a/Sources/Hummingbird/Codable/CodableProtocols.swift b/Sources/Hummingbird/Codable/CodableProtocols.swift index 2f4bc7354..132101800 100644 --- a/Sources/Hummingbird/Codable/CodableProtocols.swift +++ b/Sources/Hummingbird/Codable/CodableProtocols.swift @@ -28,7 +28,7 @@ public protocol HBRequestDecoder: Sendable { /// - Parameters: /// - type: type to decode to /// - request: request - func decode(_ type: T.Type, from request: HBRequest, context: some HBBaseRequestContext) throws -> T + func decode(_ type: T.Type, from request: HBRequest, context: some HBBaseRequestContext) async throws -> T } /// Default encoder. Outputs request with the swift string description of object diff --git a/Sources/Hummingbird/Codable/RequestDecodable.swift b/Sources/Hummingbird/Codable/RequestDecodable.swift index 502a3824f..edd81b3cb 100644 --- a/Sources/Hummingbird/Codable/RequestDecodable.swift +++ b/Sources/Hummingbird/Codable/RequestDecodable.swift @@ -33,7 +33,7 @@ extension HBRequestDecodable { /// Create using `Codable` interfaces /// - Parameter request: request /// - Throws: HBHTTPError - public init(from request: HBRequest, context: some HBBaseRequestContext) throws { - self = try request.decode(as: Self.self, using: context) + public init(from request: HBRequest, context: some HBBaseRequestContext) async throws { + self = try await request.decode(as: Self.self, using: context) } } diff --git a/Sources/Hummingbird/Router/RouteHandler.swift b/Sources/Hummingbird/Router/RouteHandler.swift index f575605fb..b534d8f74 100644 --- a/Sources/Hummingbird/Router/RouteHandler.swift +++ b/Sources/Hummingbird/Router/RouteHandler.swift @@ -39,7 +39,7 @@ /// ``` public protocol HBRouteHandler { associatedtype _Output - init(from: HBRequest, context: some HBBaseRequestContext) throws + init(from: HBRequest, context: some HBBaseRequestContext) async throws func handle(request: HBRequest, context: some HBBaseRequestContext) async throws -> _Output } @@ -52,7 +52,7 @@ extension HBRouterMethods { use handlerType: Handler.Type ) -> Self where Handler._Output == _Output { return self.on(path, method: method, options: options) { request, context -> _Output in - let handler = try Handler(from: request, context: context) + let handler = try await Handler(from: request, context: context) return try await handler.handle(request: request, context: context) } } diff --git a/Sources/Hummingbird/Router/RouterMethods.swift b/Sources/Hummingbird/Router/RouterMethods.swift index e17da530a..6a5252421 100644 --- a/Sources/Hummingbird/Router/RouterMethods.swift +++ b/Sources/Hummingbird/Router/RouterMethods.swift @@ -21,9 +21,6 @@ public struct HBRouterMethodOptions: OptionSet, Sendable { public init(rawValue: Int) { self.rawValue = rawValue } - - /// don't collate the request body, expect handler to stream it - public static let streamBody: HBRouterMethodOptions = .init(rawValue: 1 << 0) } /// Conform to `HBRouterMethods` to add standard router verb (get, post ...) methods @@ -45,77 +42,66 @@ public protocol HBRouterMethods { extension HBRouterMethods { /// GET path for async closure returning type conforming to ResponseEncodable - @discardableResult public func get( + @discardableResult public func get( _ path: String = "", options: HBRouterMethodOptions = [], - use handler: @Sendable @escaping (HBRequest, Context) async throws -> Output + use handler: @Sendable @escaping (HBRequest, Context) async throws -> some HBResponseGenerator ) -> Self { return on(path, method: .GET, options: options, use: handler) } /// PUT path for async closure returning type conforming to ResponseEncodable - @discardableResult public func put( + @discardableResult public func put( _ path: String = "", options: HBRouterMethodOptions = [], - use handler: @Sendable @escaping (HBRequest, Context) async throws -> Output + use handler: @Sendable @escaping (HBRequest, Context) async throws -> some HBResponseGenerator ) -> Self { return on(path, method: .PUT, options: options, use: handler) } /// DELETE path for async closure returning type conforming to ResponseEncodable - @discardableResult public func delete( + @discardableResult public func delete( _ path: String = "", options: HBRouterMethodOptions = [], - use handler: @Sendable @escaping (HBRequest, Context) async throws -> Output + use handler: @Sendable @escaping (HBRequest, Context) async throws -> some HBResponseGenerator ) -> Self { return on(path, method: .DELETE, options: options, use: handler) } /// HEAD path for async closure returning type conforming to ResponseEncodable - @discardableResult public func head( + @discardableResult public func head( _ path: String = "", options: HBRouterMethodOptions = [], - use handler: @Sendable @escaping (HBRequest, Context) async throws -> Output + use handler: @Sendable @escaping (HBRequest, Context) async throws -> some HBResponseGenerator ) -> Self { return on(path, method: .HEAD, options: options, use: handler) } /// POST path for async closure returning type conforming to ResponseEncodable - @discardableResult public func post( + @discardableResult public func post( _ path: String = "", options: HBRouterMethodOptions = [], - use handler: @Sendable @escaping (HBRequest, Context) async throws -> Output + use handler: @Sendable @escaping (HBRequest, Context) async throws -> some HBResponseGenerator ) -> Self { return on(path, method: .POST, options: options, use: handler) } /// PATCH path for async closure returning type conforming to ResponseEncodable - @discardableResult public func patch( + @discardableResult public func patch( _ path: String = "", options: HBRouterMethodOptions = [], - use handler: @Sendable @escaping (HBRequest, Context) async throws -> Output + use handler: @Sendable @escaping (HBRequest, Context) async throws -> some HBResponseGenerator ) -> Self { return on(path, method: .PATCH, options: options, use: handler) } - func constructResponder( + func constructResponder( options: HBRouterMethodOptions, - use closure: @Sendable @escaping (HBRequest, Context) async throws -> Output + use closure: @Sendable @escaping (HBRequest, Context) async throws -> some HBResponseGenerator ) -> HBCallbackResponder { return HBCallbackResponder { request, context in - if options.contains(.streamBody) { - let output = try await closure(request, context) - return try output.response(from: request, context: context) - } else { - var request = request - do { - request.body = try await request.body.collate(maxSize: context.applicationContext.configuration.maxUploadSize) - } catch { - throw HBHTTPError(.payloadTooLarge) - } - let output = try await closure(request, context) - return try output.response(from: request, context: context) - } + let output = try await closure(request, context) + return try output.response(from: request, context: context) } } } diff --git a/Sources/Hummingbird/Server/Request.swift b/Sources/Hummingbird/Server/Request.swift index 86f50be0c..318a1d778 100644 --- a/Sources/Hummingbird/Server/Request.swift +++ b/Sources/Hummingbird/Server/Request.swift @@ -56,9 +56,9 @@ public struct HBRequest: Sendable { /// Decode request using decoder stored at `HBApplication.decoder`. /// - Parameter type: Type you want to decode to - public func decode(as type: Type.Type, using context: some HBBaseRequestContext) throws -> Type { + public func decode(as type: Type.Type, using context: some HBBaseRequestContext) async throws -> Type { do { - return try context.applicationContext.decoder.decode(type, from: self, context: context) + return try await context.applicationContext.decoder.decode(type, from: self, context: context) } catch DecodingError.dataCorrupted(_) { let message = "The given data was not valid input." throw HBHTTPError(.badRequest, message: message) diff --git a/Sources/HummingbirdCore/Request/RequestBody.swift b/Sources/HummingbirdCore/Request/RequestBody.swift index 8dfb91a0c..4befdf437 100644 --- a/Sources/HummingbirdCore/Request/RequestBody.swift +++ b/Sources/HummingbirdCore/Request/RequestBody.swift @@ -40,7 +40,7 @@ public enum HBRequestBody: Sendable, AsyncSequence { case .byteBuffer: return self case .stream(let streamer): - return try .byteBuffer(await streamer.collect(upTo: maxSize)) + return try await .byteBuffer(streamer.collect(upTo: maxSize)) } } } diff --git a/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift b/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift index 5b08dc6a8..b4d48b7fb 100644 --- a/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift +++ b/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift @@ -42,7 +42,7 @@ extension HTTPChannelHandler { do { try await withGracefulShutdownHandler { try await withThrowingTaskGroup(of: Void.self) { group in - try await asyncChannel.executeThenClose { inbound, outbound in + try await asyncChannel.executeThenClose { inbound, outbound in let responseWriter = HBHTTPServerBodyWriter(outbound: outbound) var iterator = inbound.makeAsyncIterator() while let part = try await iterator.next() { @@ -98,7 +98,7 @@ extension HTTPChannelHandler { } onGracefulShutdown: { // set to cancelled if processingRequest.exchange(.cancelled, ordering: .relaxed) == .idle { - // only close the channel input if it is idle + // only close the channel input if it is idle asyncChannel.channel.close(mode: .input, promise: nil) } } @@ -137,3 +137,10 @@ struct HBHTTPServerBodyWriter: Sendable, HBResponseBodyWriter { try await self.outbound.write(.body(buffer)) } } + +// If we catch a too many bytes error report that as payload too large +extension NIOTooManyBytesError: HBHTTPResponseError { + public var status: NIOHTTP1.HTTPResponseStatus { .payloadTooLarge } + public var headers: NIOHTTP1.HTTPHeaders { [:] } + public func body(allocator: NIOCore.ByteBufferAllocator) -> NIOCore.ByteBuffer? { nil } +} diff --git a/Sources/HummingbirdFoundation/Codable/JSON/JSONCoding.swift b/Sources/HummingbirdFoundation/Codable/JSON/JSONCoding.swift index 731f2bc8b..ea44e3e35 100644 --- a/Sources/HummingbirdFoundation/Codable/JSON/JSONCoding.swift +++ b/Sources/HummingbirdFoundation/Codable/JSON/JSONCoding.swift @@ -40,13 +40,9 @@ extension JSONDecoder: HBRequestDecoder { /// - Parameters: /// - type: Type to decode /// - request: Request to decode from - public func decode(_ type: T.Type, from request: HBRequest, context: some HBBaseRequestContext) throws -> T { - guard case .byteBuffer(var buffer) = request.body, - let data = buffer.readData(length: buffer.readableBytes) - else { - throw HBHTTPError(.badRequest) - } - return try self.decode(T.self, from: data) + public func decode(_ type: T.Type, from request: HBRequest, context: some HBBaseRequestContext) async throws -> T { + let buffer = try await request.body.collect(upTo: context.applicationContext.configuration.maxUploadSize) + return try self.decode(T.self, from: buffer) } } diff --git a/Sources/HummingbirdFoundation/Codable/URLEncodedForm/URLEncodedForm+Request.swift b/Sources/HummingbirdFoundation/Codable/URLEncodedForm/URLEncodedForm+Request.swift index feeb24803..bbf4f9199 100644 --- a/Sources/HummingbirdFoundation/Codable/URLEncodedForm/URLEncodedForm+Request.swift +++ b/Sources/HummingbirdFoundation/Codable/URLEncodedForm/URLEncodedForm+Request.swift @@ -36,12 +36,9 @@ extension URLEncodedFormDecoder: HBRequestDecoder { /// - Parameters: /// - type: Type to decode /// - request: Request to decode from - public func decode(_ type: T.Type, from request: HBRequest, context: some HBBaseRequestContext) throws -> T { - guard case .byteBuffer(var buffer) = request.body, - let string = buffer.readString(length: buffer.readableBytes) - else { - throw HBHTTPError(.badRequest) - } + public func decode(_ type: T.Type, from request: HBRequest, context: some HBBaseRequestContext) async throws -> T { + let buffer = try await request.body.collect(upTo: context.applicationContext.configuration.maxUploadSize) + let string = String(buffer: buffer) return try self.decode(T.self, from: string) } } diff --git a/Sources/PerformanceTest/main.swift b/Sources/PerformanceTest/main.swift index b19353915..59eafee04 100644 --- a/Sources/PerformanceTest/main.swift +++ b/Sources/PerformanceTest/main.swift @@ -34,7 +34,7 @@ router.get { _, _ in // request with a body // ./wrk -c 128 -d 15s -t 8 -s scripts/post.lua http://localhost:8080 -router.post(options: .streamBody) { request, _ in +router.post { request, _ in return HBResponse(status: .ok, body: .init(asyncSequence: request.body)) } diff --git a/Tests/HummingbirdFoundationTests/HummingBirdJSONTests.swift b/Tests/HummingbirdFoundationTests/HummingBirdJSONTests.swift index a753d4427..542f83c42 100644 --- a/Tests/HummingbirdFoundationTests/HummingBirdJSONTests.swift +++ b/Tests/HummingbirdFoundationTests/HummingBirdJSONTests.swift @@ -29,7 +29,7 @@ class HummingbirdJSONTests: XCTestCase { func testDecode() async throws { let router = HBRouterBuilder(context: HBTestRouterContext.self) router.put("/user") { request, context -> HTTPResponseStatus in - guard let user = try? request.decode(as: User.self, using: context) else { throw HBHTTPError(.badRequest) } + guard let user = try? await request.decode(as: User.self, using: context) else { throw HBHTTPError(.badRequest) } XCTAssertEqual(user.name, "John Smith") XCTAssertEqual(user.email, "john.smith@email.com") XCTAssertEqual(user.age, 25) diff --git a/Tests/HummingbirdFoundationTests/URLEncodedForm/Application+URLEncodedFormTests.swift b/Tests/HummingbirdFoundationTests/URLEncodedForm/Application+URLEncodedFormTests.swift index e5ca02208..edd9ee996 100644 --- a/Tests/HummingbirdFoundationTests/URLEncodedForm/Application+URLEncodedFormTests.swift +++ b/Tests/HummingbirdFoundationTests/URLEncodedForm/Application+URLEncodedFormTests.swift @@ -29,7 +29,7 @@ class HummingBirdURLEncodedTests: XCTestCase { func testDecode() async throws { let router = HBRouterBuilder(context: HBTestRouterContext.self) router.put("/user") { request, context -> HTTPResponseStatus in - guard let user = try? request.decode(as: User.self, using: context) else { throw HBHTTPError(.badRequest) } + guard let user = try? await request.decode(as: User.self, using: context) else { throw HBHTTPError(.badRequest) } XCTAssertEqual(user.name, "John Smith") XCTAssertEqual(user.email, "john.smith@email.com") XCTAssertEqual(user.age, 25) diff --git a/Tests/HummingbirdTests/ApplicationTests.swift b/Tests/HummingbirdTests/ApplicationTests.swift index 585093cb5..378b68033 100644 --- a/Tests/HummingbirdTests/ApplicationTests.swift +++ b/Tests/HummingbirdTests/ApplicationTests.swift @@ -194,26 +194,12 @@ final class ApplicationTests: XCTestCase { } } - func testEventLoopFutureArray() async throws { - let router = HBRouterBuilder(context: HBTestRouterContext.self) - router.patch("array") { _, _ -> [String] in - return ["yes", "no"] - } - let app = HBApplication(responder: router.buildResponder()) - try await app.test(.router) { client in - try await client.XCTExecute(uri: "/array", method: .PATCH) { response in - let body = try XCTUnwrap(response.body) - XCTAssertEqual(String(buffer: body), "[\"yes\", \"no\"]") - } - } - } - func testResponseBody() async throws { let router = HBRouterBuilder(context: HBTestRouterContext.self) router .group("/echo-body") .post { request, _ -> HBResponse in - guard case .byteBuffer(let buffer) = request.body else { return .init(status: .ok) } + let buffer = try await request.body.collect(upTo: .max) return .init(status: .ok, headers: [:], body: .init(byteBuffer: buffer)) } let app = HBApplication(responder: router.buildResponder(), configuration: .init(maxUploadSize: 2 * 1024 * 1024)) @@ -230,10 +216,10 @@ final class ApplicationTests: XCTestCase { /// Test streaming of requests and streaming of responses by streaming the request body into a response streamer func testStreaming() async throws { let router = HBRouterBuilder(context: HBTestRouterContext.self) - router.post("streaming", options: .streamBody) { request, _ -> HBResponse in + router.post("streaming") { request, _ -> HBResponse in return HBResponse(status: .ok, body: .init(asyncSequence: request.body)) } - router.post("size", options: .streamBody) { request, _ -> String in + router.post("size") { request, _ -> String in var size = 0 for try await buffer in request.body { size += buffer.readableBytes @@ -263,7 +249,7 @@ final class ApplicationTests: XCTestCase { /// Test streaming of requests and streaming of responses by streaming the request body into a response streamer func testStreamingSmallBuffer() async throws { let router = HBRouterBuilder(context: HBTestRouterContext.self) - router.post("streaming", options: .streamBody) { request, _ -> HBResponse in + router.post("streaming") { request, _ -> HBResponse in return HBResponse(status: .ok, body: .init(asyncSequence: request.body)) } let app = HBApplication(responder: router.buildResponder()) @@ -310,8 +296,8 @@ final class ApplicationTests: XCTestCase { router .group("/echo-body") .post { request, _ -> ByteBuffer? in - guard case .byteBuffer(let buffer) = request.body, buffer.readableBytes > 0 else { return nil } - return buffer + let buffer = try await request.body.collect(upTo: .max) + return buffer.readableBytes > 0 ? buffer : nil } let app = HBApplication(responder: router.buildResponder()) try await app.test(.router) { client in @@ -402,10 +388,11 @@ final class ApplicationTests: XCTestCase { func testMaxUploadSize() async throws { let router = HBRouterBuilder() - router.post("upload") { _, _ in - "ok" + router.post("upload") { request, context in + _ = try await request.body.collate(maxSize: context.applicationContext.configuration.maxUploadSize) + return "ok" } - router.post("stream", options: .streamBody) { _, _ in + router.post("stream") { _, _ in "ok" } let app = HBApplication(responder: router.buildResponder(), configuration: .init(maxUploadSize: 64 * 1024)) diff --git a/Tests/HummingbirdTests/FileIOTests.swift b/Tests/HummingbirdTests/FileIOTests.swift index cf8df5757..5fa6a60cf 100644 --- a/Tests/HummingbirdTests/FileIOTests.swift +++ b/Tests/HummingbirdTests/FileIOTests.swift @@ -71,7 +71,7 @@ class FileIOTests: XCTestCase { func testWriteLargeFile() async throws { let filename = "testWriteLargeFile.txt" let router = HBRouterBuilder(context: HBTestRouterContext.self) - router.put("store", options: .streamBody) { request, context -> HTTPResponseStatus in + router.put("store") { request, context -> HTTPResponseStatus in let fileIO = HBFileIO(threadPool: context.threadPool) try await fileIO.writeFile(contents: request.body, path: filename, context: context, logger: context.logger) return .ok diff --git a/Tests/HummingbirdTests/MiddlewareTests.swift b/Tests/HummingbirdTests/MiddlewareTests.swift index b571e34ec..a70f8f137 100644 --- a/Tests/HummingbirdTests/MiddlewareTests.swift +++ b/Tests/HummingbirdTests/MiddlewareTests.swift @@ -157,7 +157,7 @@ final class MiddlewareTests: XCTestCase { let router = HBRouterBuilder(context: HBTestRouterContext.self) router.group() .add(middleware: TransformMiddleware()) - .get("test", options: .streamBody) { request, _ in + .get("test") { request, _ in return HBResponse(status: .ok, body: .init(asyncSequence: request.body)) } let app = HBApplication(responder: router.buildResponder()) diff --git a/Tests/HummingbirdTests/PersistTests.swift b/Tests/HummingbirdTests/PersistTests.swift index 5b161b196..0fd296b67 100644 --- a/Tests/HummingbirdTests/PersistTests.swift +++ b/Tests/HummingbirdTests/PersistTests.swift @@ -24,14 +24,14 @@ final class PersistTests: XCTestCase { let persist = HBMemoryPersistDriver() router.put("/persist/:tag") { request, context -> HTTPResponseStatus in - guard case .byteBuffer(let buffer) = request.body else { throw HBHTTPError(.badRequest) } + let buffer = try await request.body.collect(upTo: .max) let tag = try context.parameters.require("tag") try await persist.set(key: tag, value: String(buffer: buffer)) return .ok } router.put("/persist/:tag/:time") { request, context -> HTTPResponseStatus in guard let time = context.parameters.get("time", as: Int.self) else { throw HBHTTPError(.badRequest) } - guard case .byteBuffer(let buffer) = request.body else { throw HBHTTPError(.badRequest) } + let buffer = try await request.body.collect(upTo: .max) let tag = try context.parameters.require("tag") try await persist.set(key: tag, value: String(buffer: buffer), expires: .seconds(time)) return .ok @@ -65,7 +65,7 @@ final class PersistTests: XCTestCase { let (router, persist) = try createRouter() router.put("/create/:tag") { request, context -> HTTPResponseStatus in - guard case .byteBuffer(let buffer) = request.body else { throw HBHTTPError(.badRequest) } + let buffer = try await request.body.collect(upTo: .max) let tag = try context.parameters.require("tag") try await persist.create(key: tag, value: String(buffer: buffer)) return .ok @@ -84,7 +84,7 @@ final class PersistTests: XCTestCase { func testDoubleCreateFail() async throws { let (router, persist) = try createRouter() router.put("/create/:tag") { request, context -> HTTPResponseStatus in - guard case .byteBuffer(let buffer) = request.body else { throw HBHTTPError(.badRequest) } + let buffer = try await request.body.collect(upTo: .max) let tag = try context.parameters.require("tag") do { try await persist.create(key: tag, value: String(buffer: buffer)) @@ -154,7 +154,7 @@ final class PersistTests: XCTestCase { let (router, persist) = try createRouter() router.put("/codable/:tag") { request, context -> HTTPResponseStatus in guard let tag = context.parameters.get("tag") else { throw HBHTTPError(.badRequest) } - guard case .byteBuffer(let buffer) = request.body else { throw HBHTTPError(.badRequest) } + let buffer = try await request.body.collect(upTo: .max) try await persist.set(key: tag, value: TestCodable(buffer: String(buffer: buffer))) return .ok }