diff --git a/Benchmarks/Benchmarks/Router/RouterBenchmarks.swift b/Benchmarks/Benchmarks/Router/RouterBenchmarks.swift index 4c35f67ed..0de477f6c 100644 --- a/Benchmarks/Benchmarks/Router/RouterBenchmarks.swift +++ b/Benchmarks/Benchmarks/Router/RouterBenchmarks.swift @@ -45,7 +45,7 @@ extension Benchmark { context: Context.Type = BasicBenchmarkContext.self, configuration: Benchmark.Configuration = Benchmark.defaultConfiguration, request: HTTPRequest, - writeBody: @escaping @Sendable (HBStreamedRequestBody) async throws -> Void = { _ in }, + writeBody: @escaping @Sendable (HBStreamedRequestBody.InboundStream.TestSource) async throws -> Void = { _ in }, setupRouter: @escaping @Sendable (HBRouter) async throws -> Void ) { let router = HBRouter(context: Context.self) @@ -60,15 +60,17 @@ extension Benchmark { allocator: ByteBufferAllocator(), logger: Logger(label: "Benchmark") ) - let requestBodyStream = HBStreamedRequestBody() - let requestBody = HBRequestBody.stream(requestBodyStream) + let (inbound, source) = NIOAsyncChannelInboundStream.makeTestingStream() + let streamer = HBStreamedRequestBody(iterator: inbound.makeAsyncIterator()) + let requestBody = HBRequestBody.stream(streamer) let hbRequest = HBRequest(head: request, body: requestBody) group.addTask { let response = try await responder.respond(to: hbRequest, context: context) _ = try await response.body.write(BenchmarkBodyWriter()) } - try await writeBody(requestBodyStream) - requestBodyStream.finish() + try await writeBody(source) + source.yield(.end(nil)) + source.finish() } } } @@ -100,10 +102,10 @@ func routerBenchmarks() { configuration: .init(warmupIterations: 10), request: .init(method: .put, scheme: "http", authority: "localhost", path: "/") ) { bodyStream in - await bodyStream.send(buffer) - await bodyStream.send(buffer) - await bodyStream.send(buffer) - await bodyStream.send(buffer) + bodyStream.yield(.body(buffer)) + bodyStream.yield(.body(buffer)) + bodyStream.yield(.body(buffer)) + bodyStream.yield(.body(buffer)) } setupRouter: { router in router.put { request, _ in let body = try await request.body.collate(maxSize: .max) @@ -116,10 +118,10 @@ func routerBenchmarks() { configuration: .init(warmupIterations: 10), request: .init(method: .post, scheme: "http", authority: "localhost", path: "/") ) { bodyStream in - await bodyStream.send(buffer) - await bodyStream.send(buffer) - await bodyStream.send(buffer) - await bodyStream.send(buffer) + bodyStream.yield(.body(buffer)) + bodyStream.yield(.body(buffer)) + bodyStream.yield(.body(buffer)) + bodyStream.yield(.body(buffer)) } setupRouter: { router in router.post { request, _ in HBResponse(status: .ok, headers: [:], body: .init { writer in diff --git a/Sources/Hummingbird/Codable/JSON/JSONCoding.swift b/Sources/Hummingbird/Codable/JSON/JSONCoding.swift index 8499839cf..04435a528 100644 --- a/Sources/Hummingbird/Codable/JSON/JSONCoding.swift +++ b/Sources/Hummingbird/Codable/JSON/JSONCoding.swift @@ -40,7 +40,7 @@ extension JSONDecoder: HBRequestDecoder { /// - type: Type to decode /// - request: Request to decode from public func decode(_ type: T.Type, from request: HBRequest, context: some HBBaseRequestContext) async throws -> T { - let buffer = try await request.body.collate(maxSize: context.maxUploadSize) + let buffer = try await request.body.collect(upTo: context.maxUploadSize) return try self.decode(T.self, from: buffer) } } diff --git a/Sources/Hummingbird/Codable/URLEncodedForm/URLEncodedForm+Request.swift b/Sources/Hummingbird/Codable/URLEncodedForm/URLEncodedForm+Request.swift index 6be53514a..ac706776f 100644 --- a/Sources/Hummingbird/Codable/URLEncodedForm/URLEncodedForm+Request.swift +++ b/Sources/Hummingbird/Codable/URLEncodedForm/URLEncodedForm+Request.swift @@ -35,7 +35,7 @@ extension URLEncodedFormDecoder: HBRequestDecoder { /// - type: Type to decode /// - request: Request to decode from public func decode(_ type: T.Type, from request: HBRequest, context: some HBBaseRequestContext) async throws -> T { - let buffer = try await request.body.collate(maxSize: context.maxUploadSize) + let buffer = try await request.body.collect(upTo: context.maxUploadSize) let string = String(buffer: buffer) return try self.decode(T.self, from: string) } diff --git a/Sources/Hummingbird/Exports.swift b/Sources/Hummingbird/Exports.swift index a1612872d..06d7ce39b 100644 --- a/Sources/Hummingbird/Exports.swift +++ b/Sources/Hummingbird/Exports.swift @@ -16,7 +16,7 @@ @_exported import struct HummingbirdCore.HBHTTPError @_exported import protocol HummingbirdCore.HBHTTPResponseError @_exported import struct HummingbirdCore.HBRequest -@_exported import enum HummingbirdCore.HBRequestBody +@_exported import struct HummingbirdCore.HBRequestBody @_exported import struct HummingbirdCore.HBResponse @_exported import struct HummingbirdCore.HBResponseBody @_exported import protocol HummingbirdCore.HBResponseBodyWriter diff --git a/Sources/Hummingbird/Files/FileIO.swift b/Sources/Hummingbird/Files/FileIO.swift index 63e37edef..a720903d3 100644 --- a/Sources/Hummingbird/Files/FileIO.swift +++ b/Sources/Hummingbird/Files/FileIO.swift @@ -108,11 +108,8 @@ public struct HBFileIO: Sendable { try? handle.close() } context.logger.debug("[FileIO] PUT", metadata: ["file": .string(path)]) - switch contents { - case .byteBuffer(let buffer): - try await self.writeFile(buffer: buffer, handle: handle, on: eventLoop) - case .stream(let streamer): - try await self.writeFile(asyncSequence: streamer, handle: handle, on: eventLoop) + for try await buffer in contents { + try await self.fileIO.write(fileHandle: handle, buffer: buffer, eventLoop: eventLoop).get() } } @@ -153,20 +150,4 @@ public struct HBFileIO: Sendable { try handle.close() } } - - /// write byte buffer to file - func writeFile(buffer: ByteBuffer, handle: NIOFileHandle, on eventLoop: EventLoop) async throws { - return try await self.fileIO.write(fileHandle: handle, buffer: buffer, eventLoop: eventLoop).get() - } - - /// write output of streamer to file - func writeFile( - asyncSequence: BufferSequence, - handle: NIOFileHandle, - on eventLoop: EventLoop - ) async throws where BufferSequence.Element == ByteBuffer { - for try await buffer in asyncSequence { - try await self.fileIO.write(fileHandle: handle, buffer: buffer, eventLoop: eventLoop).get() - } - } } diff --git a/Sources/Hummingbird/Server/Request.swift b/Sources/Hummingbird/Server/Request.swift index edcbf155c..97c6bd501 100644 --- a/Sources/Hummingbird/Server/Request.swift +++ b/Sources/Hummingbird/Server/Request.swift @@ -24,8 +24,8 @@ extension HBRequest { /// - Parameter context: request context /// - Returns: Collated body public mutating func collateBody(context: some HBBaseRequestContext) async throws -> ByteBuffer { - let byteBuffer = try await self.body.collate(maxSize: context.maxUploadSize) - self.body = .byteBuffer(byteBuffer) + let byteBuffer = try await self.body.collect(upTo: context.maxUploadSize) + self.body = .init(buffer: byteBuffer) return byteBuffer } diff --git a/Sources/HummingbirdCore/Request/RequestBody.swift b/Sources/HummingbirdCore/Request/RequestBody.swift index 47cb89e5d..bc4c2d226 100644 --- a/Sources/HummingbirdCore/Request/RequestBody.swift +++ b/Sources/HummingbirdCore/Request/RequestBody.swift @@ -2,7 +2,7 @@ // // This source file is part of the Hummingbird server framework project // -// Copyright (c) 2023 the Hummingbird authors +// Copyright (c) 2023-2024 the Hummingbird authors // Licensed under Apache License v2.0 // // See LICENSE.txt for license information @@ -12,79 +12,289 @@ // //===----------------------------------------------------------------------===// -import AsyncAlgorithms +import Collections +import NIOConcurrencyHelpers import NIOCore +import NIOHTTPTypes -public enum HBRequestBody: Sendable, AsyncSequence { - case byteBuffer(ByteBuffer) - case stream(HBStreamedRequestBody) +/// Request Body +/// +/// Can be either a stream of ByteBuffers or a single ByteBuffer +public struct HBRequestBody: Sendable, AsyncSequence { + @usableFromInline + internal enum _Backing: Sendable { + case byteBuffer(ByteBuffer) + case stream(AnyAsyncSequence) + } + + @usableFromInline + internal let _backing: _Backing + + @usableFromInline + init(_ backing: _Backing) { + self._backing = backing + } + /// Initialise ``HBRequestBody`` from ByteBuffer + /// - Parameter buffer: ByteBuffer + public init(buffer: ByteBuffer) { + self.init(.byteBuffer(buffer)) + } + + @inlinable + init(asyncSequence: AS) where AS.Element == ByteBuffer { + self.init(.stream(.init(asyncSequence))) + } +} + +/// AsyncSequence protocol requirements +extension HBRequestBody { public typealias Element = ByteBuffer - public typealias AsyncIterator = HBStreamedRequestBody.AsyncIterator - public func makeAsyncIterator() -> HBStreamedRequestBody.AsyncIterator { - switch self { - case .byteBuffer: - /// The server always creates the HBRequestBody as a stream. If it is converted - /// into a single ByteBuffer it cannot be treated as a stream after that - preconditionFailure("Cannot convert collapsed request body back into a sequence") - case .stream(let streamer): - return streamer.makeAsyncIterator() + public struct AsyncIterator: AsyncIteratorProtocol { + @usableFromInline + var iterator: AnyAsyncSequence.AsyncIterator + + @usableFromInline + init(_ iterator: AnyAsyncSequence.AsyncIterator) { + self.iterator = iterator + } + + @inlinable + public mutating func next() async throws -> ByteBuffer? { + try await self.iterator.next() } } - /// Return as a single ByteBuffer. This function is required as `ByteBuffer.collect(upTo:)` - /// assumes the request body can be iterated. - public func collate(maxSize: Int) async throws -> ByteBuffer { - switch self { + @inlinable + public func makeAsyncIterator() -> AsyncIterator { + switch self._backing { case .byteBuffer(let buffer): - return buffer - case .stream: - return try await collect(upTo: maxSize) + return .init(AnyAsyncSequence(ByteBufferRequestBody(byteBuffer: buffer)).makeAsyncIterator()) + case .stream(let stream): + return .init(stream.makeAsyncIterator()) } } } -/// A type that represents an HTTP request body. -public struct HBStreamedRequestBody: Sendable, AsyncSequence { - public typealias Element = ByteBuffer +/// Extend HBRequestBody to create request body streams backed by `NIOThrowingAsyncSequenceProducer`. +extension HBRequestBody { + @usableFromInline + typealias Producer = NIOThrowingAsyncSequenceProducer< + ByteBuffer, + any Error, + NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark, + Delegate + > - public struct AsyncIterator: AsyncIteratorProtocol { - public typealias Element = ByteBuffer + /// Delegate for NIOThrowingAsyncSequenceProducer + @usableFromInline + final class Delegate: NIOAsyncSequenceProducerDelegate { + let checkedContinuations: NIOLockedValueBox>> - fileprivate var underlyingIterator: AsyncThrowingChannel.AsyncIterator + @usableFromInline + init() { + self.checkedContinuations = .init([]) + } - public mutating func next() async throws -> ByteBuffer? { - try await self.underlyingIterator.next() + @usableFromInline + func produceMore() { + self.checkedContinuations.withLockedValue { + if let cont = $0.popFirst() { + cont.resume() + } + } + } + + @usableFromInline + func didTerminate() { + self.checkedContinuations.withLockedValue { + while let cont = $0.popFirst() { + cont.resume() + } + } + } + + @usableFromInline + func waitForProduceMore() async { + await withCheckedContinuation { (cont: CheckedContinuation) in + self.checkedContinuations.withLockedValue { + $0.append(cont) + } + } + } + } + + /// A source used for driving a ``HBRequestBody`` stream. + public final class Source { + @usableFromInline + let source: Producer.Source + @usableFromInline + let delegate: Delegate + @usableFromInline + var waitForProduceMore: Bool + + @usableFromInline + init(source: Producer.Source, delegate: Delegate) { + self.source = source + self.delegate = delegate + self.waitForProduceMore = .init(false) + } + + /// Yields the element to the inbound stream. + /// + /// This function implements back pressure in that it will wait if the producer + /// sequence indicates the Source should produce more ByteBuffers. + /// + /// - Parameter element: The element to yield to the inbound stream. + @inlinable + public func yield(_ element: ByteBuffer) async throws { + // if previous call indicated we should stop producing wait until the delegate + // says we can start producing again + if self.waitForProduceMore { + await self.delegate.waitForProduceMore() + self.waitForProduceMore = false + } + let result = self.source.yield(element) + if result == .stopProducing { + self.waitForProduceMore = true + } + } + + /// Finished the inbound stream. + @inlinable + public func finish() { + self.source.finish() + } + + /// Finished the inbound stream. + /// + /// - Parameter error: The error to throw + @inlinable + public func finish(_ error: Error) { + self.source.finish(error) } } - /// HBRequestBody is internally represented by AsyncThrowingChannel - private var channel: AsyncThrowingChannel + /// Make a new ``HBRequestBody`` stream + /// - Returns: The new `HBRequestBody` and a source to yield ByteBuffers to the `HBRequestBody`. + @inlinable + public static func makeStream() -> (HBRequestBody, Source) { + let delegate = Delegate() + let newSequence = Producer.makeSequence( + backPressureStrategy: .init(lowWatermark: 2, highWatermark: 4), + finishOnDeinit: false, + delegate: delegate + ) + return (.init(asyncSequence: newSequence.sequence), Source(source: newSequence.source, delegate: delegate)) + } +} + +/// Request body that is a stream of ByteBuffers sourced from a NIOAsyncChannelInboundStream. +/// +/// This is a unicast async sequence that allows a single iterator to be created. +@usableFromInline +final class NIOAsyncChannelRequestBody: Sendable, AsyncSequence { + @usableFromInline + typealias Element = ByteBuffer + @usableFromInline + typealias InboundStream = NIOAsyncChannelInboundStream - /// Creates a new HTTP request body - @_spi(HBXCT) public init() { - self.channel = .init() + @usableFromInline + internal let underlyingIterator: UnsafeTransfer.AsyncIterator> + @usableFromInline + internal let alreadyIterated: NIOLockedValueBox + + /// Initialize NIOAsyncChannelRequestBody from AsyncIterator of a NIOAsyncChannelInboundStream + @inlinable + init(iterator: InboundStream.AsyncIterator) { + self.underlyingIterator = .init(iterator) + self.alreadyIterated = .init(false) } - public func makeAsyncIterator() -> AsyncIterator { - AsyncIterator(underlyingIterator: self.channel.makeAsyncIterator()) + /// Async Iterator for NIOAsyncChannelRequestBody + public struct AsyncIterator: AsyncIteratorProtocol { + @usableFromInline + internal var underlyingIterator: InboundStream.AsyncIterator + @usableFromInline + internal var done: Bool + + @inlinable + init(underlyingIterator: InboundStream.AsyncIterator, done: Bool = false) { + self.underlyingIterator = underlyingIterator + self.done = done + } + + @inlinable + mutating func next() async throws -> ByteBuffer? { + if self.done { return nil } + // if we are still expecting parts and the iterator finishes. + // In this case I think we can just assume we hit an .end + guard let part = try await self.underlyingIterator.next() else { return nil } + switch part { + case .body(let buffer): + return buffer + case .end: + self.done = true + return nil + default: + throw HTTPChannelError.unexpectedHTTPPart(part) + } + } + } + + @inlinable + func makeAsyncIterator() -> AsyncIterator { + // verify if an iterator has already been created. If it has then create an + // iterator that returns nothing. This could be a precondition failure (currently + // an assert) as you should not be allowed to do this. + let done = self.alreadyIterated.withLockedValue { + assert($0 == false, "Can only create iterator from request body once") + let done = $0 + $0 = true + return done + } + return AsyncIterator(underlyingIterator: self.underlyingIterator.wrappedValue, done: done) } } -extension HBStreamedRequestBody { - /// push a single ByteBuffer to the HTTP request body stream - @_spi(HBXCT) public func send(_ buffer: ByteBuffer) async { - await self.channel.send(buffer) +/// Request body stream that is a single ByteBuffer +/// +/// This is used when converting a ByteBuffer back to a stream of ByteBuffers +@usableFromInline +struct ByteBufferRequestBody: Sendable, AsyncSequence { + @usableFromInline + typealias Element = ByteBuffer + + @usableFromInline + init(byteBuffer: ByteBuffer) { + self.byteBuffer = byteBuffer } - /// pass error to HTTP request body - @_spi(HBXCT) public func fail(_ error: Error) { - self.channel.fail(error) + @usableFromInline + struct AsyncIterator: AsyncIteratorProtocol { + @usableFromInline + var byteBuffer: ByteBuffer + @usableFromInline + var iterated: Bool + + init(byteBuffer: ByteBuffer) { + self.byteBuffer = byteBuffer + self.iterated = false + } + + @inlinable + mutating func next() async throws -> ByteBuffer? { + guard self.iterated == false else { return nil } + self.iterated = true + return self.byteBuffer + } } - /// Finish HTTP request body stream - @_spi(HBXCT) public func finish() { - self.channel.finish() + @usableFromInline + func makeAsyncIterator() -> AsyncIterator { + .init(byteBuffer: self.byteBuffer) } + + let byteBuffer: ByteBuffer } diff --git a/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift b/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift index e3fa9f1e1..2972f70a0 100644 --- a/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift +++ b/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift @@ -27,6 +27,7 @@ public protocol HTTPChannelHandler: HBChildChannel { /// Internal error thrown when an unexpected HTTP part is received eg we didn't receive /// a head part when we expected one +@usableFromInline enum HTTPChannelError: Error { case unexpectedHTTPPart(HTTPRequestPart) case closeConnection @@ -43,57 +44,61 @@ extension HTTPChannelHandler { let processingRequest = ManagedAtomic(HTTPState.idle) do { try await withGracefulShutdownHandler { - try await withThrowingTaskGroup(of: Void.self) { group in - try await asyncChannel.executeThenClose { inbound, outbound in - let responseWriter = HBHTTPServerBodyWriter(outbound: outbound) - var iterator = inbound.makeAsyncIterator() - while let part = try await iterator.next() { - // set to processing unless it is cancelled then exit - guard processingRequest.exchange(.processing, ordering: .relaxed) == .idle else { break } - guard case .head(let head) = part else { - throw HTTPChannelError.unexpectedHTTPPart(part) - } - let bodyStream = HBStreamedRequestBody() - let body = HBRequestBody.stream(bodyStream) - let request = HBRequest(head: head, body: body) - // add task processing request and writing response - group.addTask { - let response: HBResponse - do { - response = try await self.responder(request, asyncChannel.channel) - } catch { - response = self.getErrorResponse(from: error, allocator: asyncChannel.channel.allocator) - } - do { - try await outbound.write(.head(response.head)) - let tailHeaders = try await response.body.write(responseWriter) - try await outbound.write(.end(tailHeaders)) - // flush request body - for try await _ in request.body {} - } catch { - // flush request body - for try await _ in request.body {} - throw error - } - if request.headers[.connection] == "close" { - throw HTTPChannelError.closeConnection - } - } - // send body parts to request - do { - // pass body part to request - while case .body(let buffer) = try await iterator.next() { - await bodyStream.send(buffer) - } - bodyStream.finish() - } catch { - // pass failed to read full http body to request - bodyStream.fail(error) - } - try await group.next() - // set to idle unless it is cancelled then exit - guard processingRequest.exchange(.idle, ordering: .relaxed) == .processing else { break } + try await asyncChannel.executeThenClose { inbound, outbound in + let responseWriter = HBHTTPServerBodyWriter(outbound: outbound) + var iterator = inbound.makeAsyncIterator() + + // read first part, verify it is a head + guard let part = try await iterator.next() else { return } + guard case .head(var head) = part else { + throw HTTPChannelError.unexpectedHTTPPart(part) + } + + while true { + // set to processing unless it is cancelled then exit + guard processingRequest.exchange(.processing, ordering: .relaxed) == .idle else { break } + + let bodyStream = NIOAsyncChannelRequestBody(iterator: iterator) + let request = HBRequest(head: head, body: .init(asyncSequence: bodyStream)) + let response: HBResponse + do { + response = try await self.responder(request, asyncChannel.channel) + } catch { + response = self.getErrorResponse(from: error, allocator: asyncChannel.channel.allocator) + } + do { + try await outbound.write(.head(response.head)) + let tailHeaders = try await response.body.write(responseWriter) + try await outbound.write(.end(tailHeaders)) + } catch { + throw error + } + if request.headers[.connection] == "close" { + throw HTTPChannelError.closeConnection } + // set to idle unless it is cancelled then exit + guard processingRequest.exchange(.idle, ordering: .relaxed) == .processing else { break } + + // Flush current request + // read until we don't have a body part + var part: HTTPRequestPart? + while true { + part = try await iterator.next() + guard case .body = part else { break } + } + // if we have an end then read the next part + if case .end = part { + part = try await iterator.next() + } + + // if part is nil break out of loop + guard let part else { + break + } + + // part should be a head, if not throw error + guard case .head(let newHead) = part else { throw HTTPChannelError.unexpectedHTTPPart(part) } + head = newHead } } } onGracefulShutdown: { diff --git a/Sources/HummingbirdCore/Utils/AnyAsyncSequence.swift b/Sources/HummingbirdCore/Utils/AnyAsyncSequence.swift new file mode 100644 index 000000000..3359cc206 --- /dev/null +++ b/Sources/HummingbirdCore/Utils/AnyAsyncSequence.swift @@ -0,0 +1,55 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +@usableFromInline +struct AnyAsyncSequence: AsyncSequence { + @usableFromInline + typealias AsyncIteratorNextCallback = () async throws -> Element? + + @usableFromInline + let makeAsyncIteratorCallback: @Sendable () -> AsyncIteratorNextCallback + + @inlinable + init(_ base: AS) where AS.Element == Element, AS: Sendable { + self.makeAsyncIteratorCallback = { + var iterator = base.makeAsyncIterator() + return { + try await iterator.next() + } + } + } + + @usableFromInline + struct AsyncIterator: AsyncIteratorProtocol { + @usableFromInline + let nextCallback: AsyncIteratorNextCallback + + @usableFromInline + init(nextCallback: @escaping AsyncIteratorNextCallback) { + self.nextCallback = nextCallback + } + + @inlinable + func next() async throws -> Element? { + try await self.nextCallback() + } + } + + @inlinable + func makeAsyncIterator() -> AsyncIterator { + .init(nextCallback: self.makeAsyncIteratorCallback()) + } +} + +extension AnyAsyncSequence: Sendable where Element: Sendable {} diff --git a/Sources/HummingbirdCore/Utils/UnsafeTransfer.swift b/Sources/HummingbirdCore/Utils/UnsafeTransfer.swift new file mode 100644 index 000000000..259d87a54 --- /dev/null +++ b/Sources/HummingbirdCore/Utils/UnsafeTransfer.swift @@ -0,0 +1,62 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2021-2022 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +/// ``UnsafeTransfer`` can be used to make non-`Sendable` values `Sendable`. +/// As the name implies, the usage of this is unsafe because it disables the sendable checking of the compiler. +/// It can be used similar to `@unsafe Sendable` but for values instead of types. +@usableFromInline +struct UnsafeTransfer { + @usableFromInline + var wrappedValue: Wrapped + + @inlinable + init(_ wrappedValue: Wrapped) { + self.wrappedValue = wrappedValue + } +} + +extension UnsafeTransfer: @unchecked Sendable {} + +extension UnsafeTransfer: Equatable where Wrapped: Equatable {} +extension UnsafeTransfer: Hashable where Wrapped: Hashable {} + +/// ``UnsafeMutableTransferBox`` can be used to make non-`Sendable` values `Sendable` and mutable. +/// It can be used to capture local mutable values in a `@Sendable` closure and mutate them from within the closure. +/// As the name implies, the usage of this is unsafe because it disables the sendable checking of the compiler and does not add any synchronisation. +@usableFromInline +final class UnsafeMutableTransferBox { + @usableFromInline + var wrappedValue: Wrapped + + @inlinable + init(_ wrappedValue: Wrapped) { + self.wrappedValue = wrappedValue + } +} + +extension UnsafeMutableTransferBox: @unchecked Sendable {} diff --git a/Sources/HummingbirdXCT/HBXCTRouter.swift b/Sources/HummingbirdXCT/HBXCTRouter.swift index 3ba6dea01..f14448c65 100644 --- a/Sources/HummingbirdXCT/HBXCTRouter.swift +++ b/Sources/HummingbirdXCT/HBXCTRouter.swift @@ -19,6 +19,7 @@ import HTTPTypes import Logging import NIOConcurrencyHelpers import NIOCore +import NIOHTTPTypes import NIOPosix import ServiceLifecycle @@ -90,10 +91,10 @@ struct HBXCTRouter: HBXCTApplication where Responder.Con func execute(uri: String, method: HTTPRequest.Method, headers: HTTPFields, body: ByteBuffer?) async throws -> HBXCTResponse { return try await withThrowingTaskGroup(of: HBXCTResponse.self) { group in - let streamer = HBStreamedRequestBody() + let (stream, source) = HBRequestBody.makeStream() let request = HBRequest( head: .init(method: method, scheme: "http", authority: "localhost", path: uri, headerFields: headers), - body: .stream(streamer) + body: stream ) let logger = self.logger.with(metadataKey: "hb_id", value: .stringConvertible(RequestID())) let context = self.makeContext(logger) @@ -110,7 +111,6 @@ struct HBXCTRouter: HBXCTApplication where Responder.Con } let responseWriter = RouterResponseWriter() let trailerHeaders = try await response.body.write(responseWriter) - for try await _ in request.body {} return responseWriter.collated.withLockedValue { collated in HBXCTResponse(head: response.head, body: collated, trailerHeaders: trailerHeaders) } @@ -120,10 +120,10 @@ struct HBXCTRouter: HBXCTApplication where Responder.Con while body.readableBytes > 0 { let chunkSize = min(32 * 1024, body.readableBytes) let buffer = body.readSlice(length: chunkSize)! - await streamer.send(buffer) + try await source.yield(buffer) } } - streamer.finish() + source.finish() return try await group.next()! } } diff --git a/Tests/HummingbirdCoreTests/CoreTests.swift b/Tests/HummingbirdCoreTests/CoreTests.swift index c91f8d817..f1cc58c6b 100644 --- a/Tests/HummingbirdCoreTests/CoreTests.swift +++ b/Tests/HummingbirdCoreTests/CoreTests.swift @@ -85,7 +85,7 @@ class HummingBirdCoreTests: XCTestCase { func testConsumeBody() async throws { try await testServer( responder: { request, _ in - let buffer = try await request.body.collate(maxSize: .max) + let buffer = try await request.body.collect(upTo: .max) return HBResponse(status: .ok, body: .init(byteBuffer: buffer)) }, configuration: .init(address: .hostname(port: 0)), @@ -203,7 +203,7 @@ class HummingBirdCoreTests: XCTestCase { } try await testServer( responder: { request, _ in - _ = try await request.body.collate(maxSize: .max) + _ = try await request.body.collect(upTo: .max) return HBResponse(status: .ok) }, httpChannelSetup: .http1(additionalChannelHandlers: [CreateErrorHandler()]), @@ -269,7 +269,7 @@ class HummingBirdCoreTests: XCTestCase { } try await testServer( responder: { request, _ in - _ = try await request.body.collate(maxSize: .max) + _ = try await request.body.collect(upTo: .max) return .init(status: .ok) }, httpChannelSetup: .http1(additionalChannelHandlers: [HTTPServerIncompleteRequest(), IdleStateHandler(readTimeout: .seconds(1))]), @@ -292,7 +292,7 @@ class HummingBirdCoreTests: XCTestCase { func testWriteIdleTimeout() async throws { try await testServer( responder: { request, _ in - _ = try await request.body.collate(maxSize: .max) + _ = try await request.body.collect(upTo: .max) return .init(status: .ok) }, httpChannelSetup: .http1(additionalChannelHandlers: [IdleStateHandler(writeTimeout: .seconds(1))]), diff --git a/Tests/HummingbirdTests/ApplicationTests.swift b/Tests/HummingbirdTests/ApplicationTests.swift index 31d20e0e9..846752566 100644 --- a/Tests/HummingbirdTests/ApplicationTests.swift +++ b/Tests/HummingbirdTests/ApplicationTests.swift @@ -191,7 +191,7 @@ final class ApplicationTests: XCTestCase { router .group("/echo-body") .post { request, _ -> HBResponse in - let buffer = try await request.body.collate(maxSize: .max) + let buffer = try await request.body.collect(upTo: .max) return .init(status: .ok, headers: [:], body: .init(byteBuffer: buffer)) } let app = HBApplication(responder: router.buildResponder()) @@ -268,7 +268,7 @@ final class ApplicationTests: XCTestCase { let router = HBRouter() router.middlewares.add(CollateMiddleware()) router.put("/hello") { request, _ -> String in - guard case .byteBuffer(let buffer) = request.body else { throw HBHTTPError(.internalServerError) } + let buffer = try await request.body.collect(upTo: .max) return buffer.readableBytes.description } let app = HBApplication(responder: router.buildResponder()) @@ -282,12 +282,33 @@ final class ApplicationTests: XCTestCase { } } + func testDoubleStreaming() async throws { + let router = HBRouter() + router.post("size") { request, context -> String in + var request = request + _ = try await request.collateBody(context: context) + var size = 0 + for try await buffer in request.body { + size += buffer.readableBytes + } + return size.description + } + let app = HBApplication(responder: router.buildResponder()) + + try await app.test(.router) { client in + let buffer = self.randomBuffer(size: 100_000) + try await client.XCTExecute(uri: "/size", method: .post, body: buffer) { response in + XCTAssertEqual(String(buffer: response.body), "100000") + } + } + } + func testOptional() async throws { let router = HBRouter() router .group("/echo-body") .post { request, _ -> ByteBuffer? in - let buffer = try await request.body.collate(maxSize: .max) + let buffer = try await request.body.collect(upTo: .max) return buffer.readableBytes > 0 ? buffer : nil } let app = HBApplication(responder: router.buildResponder()) @@ -392,7 +413,7 @@ final class ApplicationTests: XCTestCase { } let router = HBRouter(context: MaxUploadRequestContext.self) router.post("upload") { request, context in - _ = try await request.body.collate(maxSize: context.maxUploadSize) + _ = try await request.body.collect(upTo: context.maxUploadSize) return "ok" } router.post("stream") { _, _ in @@ -608,14 +629,14 @@ final class ApplicationTests: XCTestCase { /// test we can create out own application type conforming to HBApplicationProtocol func testBidirectionalStreaming() async throws { - let buffer = randomBuffer(size: 1024 * 1024) + let buffer = self.randomBuffer(size: 1024 * 1024) let router = HBRouter() router.post("/") { request, context -> HBResponse in .init( status: .ok, body: .init { writer in for try await buffer in request.body { - let processed = context.allocator.buffer(bytes: buffer.readableBytesView.map {$0 ^ 0xff }) + let processed = context.allocator.buffer(bytes: buffer.readableBytesView.map { $0 ^ 0xFF }) try await writer.write(processed) } } @@ -624,7 +645,7 @@ final class ApplicationTests: XCTestCase { let app = HBApplication(router: router) try await app.test(.live) { client in try await client.XCTExecute(uri: "/", method: .post, body: buffer) { response in - XCTAssertEqual(response.body, ByteBuffer(bytes: buffer.readableBytesView.map {$0 ^ 0xff })) + XCTAssertEqual(response.body, ByteBuffer(bytes: buffer.readableBytesView.map { $0 ^ 0xFF })) } } } diff --git a/Tests/HummingbirdTests/PersistTests.swift b/Tests/HummingbirdTests/PersistTests.swift index 23bb5b6bc..7253e8f5c 100644 --- a/Tests/HummingbirdTests/PersistTests.swift +++ b/Tests/HummingbirdTests/PersistTests.swift @@ -24,14 +24,14 @@ final class PersistTests: XCTestCase { let persist = HBMemoryPersistDriver() router.put("/persist/:tag") { request, context -> HTTPResponse.Status in - let buffer = try await request.body.collate(maxSize: .max) + let buffer = try await request.body.collect(upTo: .max) let tag = try context.parameters.require("tag") try await persist.set(key: tag, value: String(buffer: buffer)) return .ok } router.put("/persist/:tag/:time") { request, context -> HTTPResponse.Status in guard let time = context.parameters.get("time", as: Int.self) else { throw HBHTTPError(.badRequest) } - let buffer = try await request.body.collate(maxSize: .max) + let buffer = try await request.body.collect(upTo: .max) let tag = try context.parameters.require("tag") try await persist.set(key: tag, value: String(buffer: buffer), expires: .seconds(time)) return .ok @@ -64,7 +64,7 @@ final class PersistTests: XCTestCase { let (router, persist) = try createRouter() router.put("/create/:tag") { request, context -> HTTPResponse.Status in - let buffer = try await request.body.collate(maxSize: .max) + let buffer = try await request.body.collect(upTo: .max) let tag = try context.parameters.require("tag") try await persist.create(key: tag, value: String(buffer: buffer)) return .ok @@ -82,7 +82,7 @@ final class PersistTests: XCTestCase { func testDoubleCreateFail() async throws { let (router, persist) = try createRouter() router.put("/create/:tag") { request, context -> HTTPResponse.Status in - let buffer = try await request.body.collate(maxSize: .max) + let buffer = try await request.body.collect(upTo: .max) let tag = try context.parameters.require("tag") do { try await persist.create(key: tag, value: String(buffer: buffer)) @@ -150,7 +150,7 @@ final class PersistTests: XCTestCase { let (router, persist) = try createRouter() router.put("/codable/:tag") { request, context -> HTTPResponse.Status in guard let tag = context.parameters.get("tag") else { throw HBHTTPError(.badRequest) } - let buffer = try await request.body.collate(maxSize: .max) + let buffer = try await request.body.collect(upTo: .max) try await persist.set(key: tag, value: TestCodable(buffer: String(buffer: buffer))) return .ok }