diff --git a/Sources/HummingbirdCore/Request/RequestBody.swift b/Sources/HummingbirdCore/Request/RequestBody.swift index 5a0600b9..c49bf13e 100644 --- a/Sources/HummingbirdCore/Request/RequestBody.swift +++ b/Sources/HummingbirdCore/Request/RequestBody.swift @@ -125,7 +125,7 @@ extension RequestBody { /// Delegate for NIOThrowingAsyncSequenceProducer @usableFromInline - final class Delegate: NIOAsyncSequenceProducerDelegate { + final class Delegate: NIOAsyncSequenceProducerDelegate, Sendable { let checkedContinuations: NIOLockedValueBox>> @usableFromInline @@ -162,13 +162,13 @@ extension RequestBody { } /// A source used for driving a ``RequestBody`` stream. - public final class Source { + public final class Source: Sendable { @usableFromInline let source: Producer.Source @usableFromInline let delegate: Delegate @usableFromInline - var waitForProduceMore: Bool + let waitForProduceMore: NIOLockedValueBox @usableFromInline init(source: Producer.Source, delegate: Delegate) { @@ -187,13 +187,13 @@ extension RequestBody { public func yield(_ element: ByteBuffer) async throws { // if previous call indicated we should stop producing wait until the delegate // says we can start producing again - if self.waitForProduceMore { + if self.waitForProduceMore.withLockedValue({ $0 }) { await self.delegate.waitForProduceMore() - self.waitForProduceMore = false + self.waitForProduceMore.withLockedValue { $0 = false } } let result = self.source.yield(element) if result == .stopProducing { - self.waitForProduceMore = true + self.waitForProduceMore.withLockedValue { $0 = true } } } diff --git a/Tests/HummingbirdTests/ApplicationTests.swift b/Tests/HummingbirdTests/ApplicationTests.swift index c13a1b92..3063a408 100644 --- a/Tests/HummingbirdTests/ApplicationTests.swift +++ b/Tests/HummingbirdTests/ApplicationTests.swift @@ -833,6 +833,39 @@ final class ApplicationTests: XCTestCase { XCTAssertEqual(format.error.message, message) } } + + /// Test AsyncSequence returned by RequestBody.makeStream() + func testMakeStream() async throws { + let router = Router() + router.post("streaming") { request, context -> Response in + let body = try await withThrowingTaskGroup(of: Void.self) { group in + let (requestBody, source) = RequestBody.makeStream() + group.addTask { + for try await buffer in request.body { + try await source.yield(buffer) + } + source.finish() + } + var body = ByteBuffer() + for try await buffer in requestBody { + var buffer = buffer + body.writeBuffer(&buffer) + } + return body + } + return Response(status: .ok, body: .init(byteBuffer: body)) + } + let app = Application(responder: router.buildResponder()) + + try await app.test(.router) { client in + + let buffer = Self.randomBuffer(size: 640_001) + try await client.execute(uri: "/streaming", method: .post, body: buffer) { response in + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.body, buffer) + } + } + } } /// HTTPField used during tests