From 294b62034e53265cd4af4ceab54ce425a9456666 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Fri, 13 Dec 2024 15:27:55 +0000 Subject: [PATCH] Storing underlying request part iterator with RequestBody (#634) * Store original request body stream after it has been edited * Add RequestBodyMergedWithUnderlyingRequestPartIterator And use in consumeWithInboundCloseHandler * Add tests for running consumeWIthInboundHandler after editing * comments and tests * re-instate fileprivate after rebase --- Sources/HummingbirdCore/Request/Request.swift | 20 ++++- .../Request/RequestBody+inboundClose.swift | 55 ++++++++----- .../HummingbirdCore/Request/RequestBody.swift | 22 ++++-- ...gedWithUnderlyingRequestPartIterator.swift | 79 +++++++++++++++++++ Tests/HummingbirdTests/ApplicationTests.swift | 66 ++++++++++++++++ 5 files changed, 211 insertions(+), 31 deletions(-) create mode 100644 Sources/HummingbirdCore/Request/RequestBodyMergedWithUnderlyingRequestPartIterator.swift diff --git a/Sources/HummingbirdCore/Request/Request.swift b/Sources/HummingbirdCore/Request/Request.swift index 9b27df89..e21694e7 100644 --- a/Sources/HummingbirdCore/Request/Request.swift +++ b/Sources/HummingbirdCore/Request/Request.swift @@ -25,7 +25,7 @@ public struct Request: Sendable { /// HTTP head public let head: HTTPRequest /// Body of HTTP request - public var body: RequestBody + private var _body: RequestBody /// Request HTTP method @inlinable public var method: HTTPRequest.Method { self.head.method } @@ -33,6 +33,20 @@ public struct Request: Sendable { @inlinable public var headers: HTTPFields { self.head.headerFields } + public var body: RequestBody { + get { _body } + set { + let original = _body.originalRequestBody + switch newValue._backing { + case .nioAsyncChannelRequestBody: + self._body = body + case .byteBuffer(let buffer, _): + self._body = .init(.byteBuffer(buffer, original)) + case .anyAsyncSequence(let seq, _): + self._body = .init(.anyAsyncSequence(seq, original)) + } + } + } // MARK: Initialization /// Create new Request @@ -45,7 +59,7 @@ public struct Request: Sendable { ) { self.uri = .init(head.path ?? "") self.head = head - self.body = body + self._body = body } /// Create new Request @@ -58,7 +72,7 @@ public struct Request: Sendable { ) { self.uri = .init(head.path ?? "") self.head = head - self.body = .init(nioAsyncChannelInbound: .init(iterator: bodyIterator)) + self._body = .init(nioAsyncChannelInbound: .init(iterator: bodyIterator)) } /// Collapse body into one ByteBuffer. diff --git a/Sources/HummingbirdCore/Request/RequestBody+inboundClose.swift b/Sources/HummingbirdCore/Request/RequestBody+inboundClose.swift index dab3220e..2bf9bb2c 100644 --- a/Sources/HummingbirdCore/Request/RequestBody+inboundClose.swift +++ b/Sources/HummingbirdCore/Request/RequestBody+inboundClose.swift @@ -37,22 +37,35 @@ extension RequestBody { _ operation: (RequestBody) async throws -> Value, onInboundClosed: @Sendable @escaping () -> Void ) async throws -> Value { - let iterator: UnsafeTransfer.AsyncIterator> = - switch self._backing { - case .nioAsyncChannelRequestBody(let iterator): - iterator.underlyingIterator - default: - preconditionFailure("Cannot run consumeWithInboundCloseHandler on edited request body") - } let (requestBody, source) = RequestBody.makeStream() - return try await withInboundCloseHandler( - iterator: iterator.wrappedValue, - source: source, - operation: { - try await operation(requestBody) - }, - onInboundClosed: onInboundClosed - ) + switch self._backing { + case .nioAsyncChannelRequestBody(let body): + return try await withInboundCloseHandler( + iterator: body.underlyingIterator.wrappedValue, + source: source, + operation: { + try await operation(requestBody) + }, + onInboundClosed: onInboundClosed + ) + + case .byteBuffer(_, .some(let originalRequestBody)), .anyAsyncSequence(_, .some(let originalRequestBody)): + let iterator = + self + .mergeWithUnderlyingRequestPartIterator(originalRequestBody.underlyingIterator.wrappedValue) + .makeAsyncIterator() + return try await withInboundCloseHandler( + iterator: iterator, + source: source, + operation: { + try await operation(requestBody) + }, + onInboundClosed: onInboundClosed + ) + + default: + preconditionFailure("Cannot run consumeWithInboundCloseHandler on edited request body") + } } /// Run provided closure but cancel it if the inbound request part stream is closed. @@ -94,13 +107,13 @@ extension RequestBody { } } - fileprivate func withInboundCloseHandler( + fileprivate func withInboundCloseHandler( isolation: isolated (any Actor)? = #isolation, - iterator: NIOAsyncChannelInboundStream.AsyncIterator, + iterator: AsyncIterator, source: RequestBody.Source, operation: () async throws -> Value, onInboundClosed: @Sendable @escaping () -> Void - ) async throws -> Value { + ) async throws -> Value where AsyncIterator.Element == HTTPRequestPart { let unsafeIterator = UnsafeTransfer(iterator) let unsafeOnInboundClosed = UnsafeTransfer(onInboundClosed) let value = try await withThrowingTaskGroup(of: Void.self) { group in @@ -123,10 +136,10 @@ extension RequestBody { case nextRequestReady } - fileprivate func iterate( - iterator: NIOAsyncChannelInboundStream.AsyncIterator, + fileprivate func iterate( + iterator: AsyncIterator, source: RequestBody.Source - ) async throws -> IterateResult { + ) async throws -> IterateResult where AsyncIterator.Element == HTTPRequestPart { var iterator = iterator while let part = try await iterator.next() { switch part { diff --git a/Sources/HummingbirdCore/Request/RequestBody.swift b/Sources/HummingbirdCore/Request/RequestBody.swift index c49bf13e..082d3a4f 100644 --- a/Sources/HummingbirdCore/Request/RequestBody.swift +++ b/Sources/HummingbirdCore/Request/RequestBody.swift @@ -23,9 +23,9 @@ import NIOHTTPTypes public struct RequestBody: Sendable, AsyncSequence { @usableFromInline internal enum _Backing: Sendable { - case byteBuffer(ByteBuffer) + case byteBuffer(ByteBuffer, NIOAsyncChannelRequestBody?) case nioAsyncChannelRequestBody(NIOAsyncChannelRequestBody) - case anyAsyncSequence(AnyAsyncSequence) + case anyAsyncSequence(AnyAsyncSequence, NIOAsyncChannelRequestBody?) } @usableFromInline @@ -40,7 +40,7 @@ public struct RequestBody: Sendable, AsyncSequence { /// - Parameter buffer: ByteBuffer @inlinable public init(buffer: ByteBuffer) { - self.init(.byteBuffer(buffer)) + self.init(.byteBuffer(buffer, nil)) } /// Initialise ``RequestBody`` from AsyncSequence of ByteBuffers @@ -54,7 +54,7 @@ public struct RequestBody: Sendable, AsyncSequence { /// - Parameter asyncSequence: AsyncSequence @inlinable public init(asyncSequence: AS) where AS.Element == ByteBuffer { - self.init(.anyAsyncSequence(.init(asyncSequence))) + self.init(.anyAsyncSequence(.init(asyncSequence), nil)) } } @@ -103,14 +103,22 @@ extension RequestBody { @inlinable public func makeAsyncIterator() -> AsyncIterator { switch self._backing { - case .byteBuffer(let buffer): + case .byteBuffer(let buffer, _): return .init(.byteBuffer(buffer)) case .nioAsyncChannelRequestBody(let requestBody): return .init(.nioAsyncChannelRequestBody(requestBody.makeAsyncIterator())) - case .anyAsyncSequence(let stream): + case .anyAsyncSequence(let stream, _): return .init(.anyAsyncSequence(stream.makeAsyncIterator())) } } + + var originalRequestBody: NIOAsyncChannelRequestBody? { + switch _backing { + case .nioAsyncChannelRequestBody(let body): body + case .byteBuffer(_, let body): body + case .anyAsyncSequence: nil + } + } } /// Extend RequestBody to create request body streams backed by `NIOThrowingAsyncSequenceProducer`. @@ -235,7 +243,7 @@ package struct NIOAsyncChannelRequestBody: Sendable, AsyncSequence { public typealias InboundStream = NIOAsyncChannelInboundStream @usableFromInline - internal let underlyingIterator: UnsafeTransfer.AsyncIterator> + internal let underlyingIterator: UnsafeTransfer @usableFromInline internal let alreadyIterated: NIOLockedValueBox diff --git a/Sources/HummingbirdCore/Request/RequestBodyMergedWithUnderlyingRequestPartIterator.swift b/Sources/HummingbirdCore/Request/RequestBodyMergedWithUnderlyingRequestPartIterator.swift new file mode 100644 index 00000000..719c1dd6 --- /dev/null +++ b/Sources/HummingbirdCore/Request/RequestBodyMergedWithUnderlyingRequestPartIterator.swift @@ -0,0 +1,79 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2023-2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIOCore +import NIOHTTPTypes + +/// AsyncSequence used by consumeWithInboundCloseHandler +/// +/// It will provide the buffers output by the ResponseBody and when that finishes will start +/// iterating what is left of the underlying request part stream, and continue iterating until +/// it hits the next head +struct RequestBodyMergedWithUnderlyingRequestPartIterator: AsyncSequence where Base.Element == ByteBuffer { + typealias Element = HTTPRequestPart + let base: Base + let underlyingIterator: NIOAsyncChannelInboundStream.AsyncIterator + + struct AsyncIterator: AsyncIteratorProtocol { + enum CurrentAsyncIterator { + case base(Base.AsyncIterator, underlying: NIOAsyncChannelInboundStream.AsyncIterator) + case underlying(NIOAsyncChannelInboundStream.AsyncIterator) + case done + } + var current: CurrentAsyncIterator + + init(iterator: Base.AsyncIterator, underlying: NIOAsyncChannelInboundStream.AsyncIterator) { + self.current = .base(iterator, underlying: underlying) + } + + mutating func next() async throws -> HTTPRequestPart? { + switch self.current { + case .base(var base, let underlying): + if let element = try await base.next() { + self.current = .base(base, underlying: underlying) + return .body(element) + } else { + self.current = .underlying(underlying) + return .end(nil) + } + + case .underlying(var underlying): + while true { + let part = try await underlying.next() + if case .head = part { + self.current = .done + return part + } + } + self.current = .underlying(underlying) + return nil + + case .done: + return nil + } + } + } + + func makeAsyncIterator() -> AsyncIterator { + .init(iterator: base.makeAsyncIterator(), underlying: underlyingIterator) + } +} + +extension RequestBody { + func mergeWithUnderlyingRequestPartIterator( + _ iterator: NIOAsyncChannelInboundStream.AsyncIterator + ) -> RequestBodyMergedWithUnderlyingRequestPartIterator { + .init(base: self, underlyingIterator: iterator) + } +} diff --git a/Tests/HummingbirdTests/ApplicationTests.swift b/Tests/HummingbirdTests/ApplicationTests.swift index a8c26a91..1d75baa7 100644 --- a/Tests/HummingbirdTests/ApplicationTests.swift +++ b/Tests/HummingbirdTests/ApplicationTests.swift @@ -893,6 +893,72 @@ final class ApplicationTests: XCTestCase { } } } + + /// Test consumeWithInboundHandler after having collected the Request body + @available(macOS 15, iOS 18, tvOS 18, *) + func testConsumeWithInboundHandlerAfterCollect() async throws { + let router = Router() + router.post("streaming") { request, context -> Response in + var request = request + _ = try await request.collectBody(upTo: .max) + let request2 = request + return Response( + status: .ok, + body: .init { writer in + try await request2.body.consumeWithInboundCloseHandler { body in + try await writer.write(body) + } onInboundClosed: { + } + try await writer.finish(nil) + } + ) + } + let app = Application(responder: router.buildResponder()) + + try await app.test(.live) { client in + let buffer = Self.randomBuffer(size: 640_001) + try await client.execute(uri: "/streaming", method: .post, body: buffer) { response in + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.body, buffer) + } + } + } + + /// Test consumeWithInboundHandler after having replaced Request.body with a new streamed RequestBody + @available(macOS 15, iOS 18, tvOS 18, *) + func testConsumeWithInboundHandlerAfterReplacingBody() async throws { + let router = Router() + router.post("streaming") { request, context -> Response in + var request = request + request.body = .init( + asyncSequence: request.body.map { + let view = $0.readableBytesView.map { $0 ^ 255 } + return ByteBuffer(bytes: view) + } + ) + let request2 = request + return Response( + status: .ok, + body: .init { writer in + try await request2.body.consumeWithInboundCloseHandler { body in + try await writer.write(body) + } onInboundClosed: { + } + try await writer.finish(nil) + } + ) + } + let app = Application(responder: router.buildResponder()) + + try await app.test(.live) { client in + let buffer = Self.randomBuffer(size: 640_001) + let xorBuffer = ByteBuffer(bytes: buffer.readableBytesView.map { $0 ^ 255 }) + try await client.execute(uri: "/streaming", method: .post, body: buffer) { response in + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.body, xorBuffer) + } + } + } #endif }