diff --git a/Sources/HummingbirdCore/Request/RequestBody.swift b/Sources/HummingbirdCore/Request/RequestBody.swift index a559ad27..5a0600b9 100644 --- a/Sources/HummingbirdCore/Request/RequestBody.swift +++ b/Sources/HummingbirdCore/Request/RequestBody.swift @@ -24,7 +24,8 @@ public struct RequestBody: Sendable, AsyncSequence { @usableFromInline internal enum _Backing: Sendable { case byteBuffer(ByteBuffer) - case stream(AnyAsyncSequence) + case nioAsyncChannelRequestBody(NIOAsyncChannelRequestBody) + case anyAsyncSequence(AnyAsyncSequence) } @usableFromInline @@ -37,15 +38,23 @@ public struct RequestBody: Sendable, AsyncSequence { /// Initialise ``RequestBody`` from ByteBuffer /// - Parameter buffer: ByteBuffer + @inlinable public init(buffer: ByteBuffer) { self.init(.byteBuffer(buffer)) } + /// Initialise ``RequestBody`` from AsyncSequence of ByteBuffers + /// - Parameter asyncSequence: AsyncSequence + @inlinable + package init(nioAsyncChannelInbound: NIOAsyncChannelRequestBody) { + self.init(.nioAsyncChannelRequestBody(nioAsyncChannelInbound)) + } + /// Initialise ``RequestBody`` from AsyncSequence of ByteBuffers /// - Parameter asyncSequence: AsyncSequence @inlinable public init(asyncSequence: AS) where AS.Element == ByteBuffer { - self.init(.stream(.init(asyncSequence))) + self.init(.anyAsyncSequence(.init(asyncSequence))) } } @@ -55,16 +64,39 @@ extension RequestBody { public struct AsyncIterator: AsyncIteratorProtocol { @usableFromInline - var iterator: AnyAsyncSequence.AsyncIterator + internal enum _Backing { + case byteBuffer(ByteBuffer) + case nioAsyncChannelRequestBody(NIOAsyncChannelRequestBody.AsyncIterator) + case anyAsyncSequence(AnyAsyncSequence.AsyncIterator) + case done + } @usableFromInline - init(_ iterator: AnyAsyncSequence.AsyncIterator) { - self.iterator = iterator + var _backing: _Backing + + @usableFromInline + init(_ backing: _Backing) { + self._backing = backing } @inlinable public mutating func next() async throws -> ByteBuffer? { - try await self.iterator.next() + switch self._backing { + case .byteBuffer(let buffer): + self._backing = .done + return buffer + + case .nioAsyncChannelRequestBody(var iterator): + let next = try await iterator.next() + self._backing = .nioAsyncChannelRequestBody(iterator) + return next + + case .anyAsyncSequence(let iterator): + return try await iterator.next() + + case .done: + return nil + } } } @@ -72,9 +104,11 @@ extension RequestBody { public func makeAsyncIterator() -> AsyncIterator { switch self._backing { case .byteBuffer(let buffer): - return .init(AnyAsyncSequence(ByteBufferRequestBody(byteBuffer: buffer)).makeAsyncIterator()) - case .stream(let stream): - return .init(stream.makeAsyncIterator()) + return .init(.byteBuffer(buffer)) + case .nioAsyncChannelRequestBody(let requestBody): + return .init(.nioAsyncChannelRequestBody(requestBody.makeAsyncIterator())) + case .anyAsyncSequence(let stream): + return .init(.anyAsyncSequence(stream.makeAsyncIterator())) } } } @@ -195,7 +229,8 @@ extension RequestBody { /// Request body that is a stream of ByteBuffers sourced from a NIOAsyncChannelInboundStream. /// /// This is a unicast async sequence that allows a single iterator to be created. -public final class NIOAsyncChannelRequestBody: Sendable, AsyncSequence { +@usableFromInline +package struct NIOAsyncChannelRequestBody: Sendable, AsyncSequence { public typealias Element = ByteBuffer public typealias InboundStream = NIOAsyncChannelInboundStream @@ -256,44 +291,3 @@ public final class NIOAsyncChannelRequestBody: Sendable, AsyncSequence { return AsyncIterator(underlyingIterator: self.underlyingIterator.wrappedValue, done: done) } } - -/// Request body stream that is a single ByteBuffer -/// -/// This is used when converting a ByteBuffer back to a stream of ByteBuffers -@usableFromInline -struct ByteBufferRequestBody: Sendable, AsyncSequence { - @usableFromInline - typealias Element = ByteBuffer - - @usableFromInline - init(byteBuffer: ByteBuffer) { - self.byteBuffer = byteBuffer - } - - @usableFromInline - struct AsyncIterator: AsyncIteratorProtocol { - @usableFromInline - var byteBuffer: ByteBuffer - @usableFromInline - var iterated: Bool - - init(byteBuffer: ByteBuffer) { - self.byteBuffer = byteBuffer - self.iterated = false - } - - @inlinable - mutating func next() async throws -> ByteBuffer? { - guard self.iterated == false else { return nil } - self.iterated = true - return self.byteBuffer - } - } - - @usableFromInline - func makeAsyncIterator() -> AsyncIterator { - .init(byteBuffer: self.byteBuffer) - } - - let byteBuffer: ByteBuffer -} diff --git a/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift b/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift index 0b8d8115..e61c5149 100644 --- a/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift +++ b/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift @@ -47,7 +47,7 @@ extension HTTPChannelHandler { while true { let bodyStream = NIOAsyncChannelRequestBody(iterator: iterator) - let request = Request(head: head, body: .init(asyncSequence: bodyStream)) + let request = Request(head: head, body: .init(nioAsyncChannelInbound: bodyStream)) let responseWriter = ResponseWriter(outbound: outbound) do { try await self.responder(request, responseWriter, asyncChannel.channel) diff --git a/Sources/HummingbirdHTTP2/HTTP2StreamChannel.swift b/Sources/HummingbirdHTTP2/HTTP2StreamChannel.swift index af385b9b..3d266c4f 100644 --- a/Sources/HummingbirdHTTP2/HTTP2StreamChannel.swift +++ b/Sources/HummingbirdHTTP2/HTTP2StreamChannel.swift @@ -72,7 +72,7 @@ struct HTTP2StreamChannel: ServerChildChannel { throw HTTPChannelError.unexpectedHTTPPart(part) } let bodyStream = NIOAsyncChannelRequestBody(iterator: iterator) - let request = Request(head: head, body: .init(asyncSequence: bodyStream)) + let request = Request(head: head, body: .init(nioAsyncChannelInbound: bodyStream)) let responseWriter = ResponseWriter(outbound: outbound) try await self.responder(request, responseWriter, asyncChannel.channel) }