Skip to content

Commit

Permalink
Storing underlying request part iterator with RequestBody (#634)
Browse files Browse the repository at this point in the history
* Store original request body stream after it has been edited

* Add RequestBodyMergedWithUnderlyingRequestPartIterator

And use in consumeWithInboundCloseHandler

* Add tests for running consumeWIthInboundHandler after editing

* comments and tests

* re-instate fileprivate after rebase
  • Loading branch information
adam-fowler authored Dec 13, 2024
1 parent 7f08882 commit 294b620
Show file tree
Hide file tree
Showing 5 changed files with 211 additions and 31 deletions.
20 changes: 17 additions & 3 deletions Sources/HummingbirdCore/Request/Request.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,28 @@ public struct Request: Sendable {
/// HTTP head
public let head: HTTPRequest
/// Body of HTTP request
public var body: RequestBody
private var _body: RequestBody
/// Request HTTP method
@inlinable
public var method: HTTPRequest.Method { self.head.method }
/// Request HTTP headers
@inlinable
public var headers: HTTPFields { self.head.headerFields }

public var body: RequestBody {
get { _body }
set {
let original = _body.originalRequestBody
switch newValue._backing {
case .nioAsyncChannelRequestBody:
self._body = body
case .byteBuffer(let buffer, _):
self._body = .init(.byteBuffer(buffer, original))
case .anyAsyncSequence(let seq, _):
self._body = .init(.anyAsyncSequence(seq, original))
}
}
}
// MARK: Initialization

/// Create new Request
Expand All @@ -45,7 +59,7 @@ public struct Request: Sendable {
) {
self.uri = .init(head.path ?? "")
self.head = head
self.body = body
self._body = body
}

/// Create new Request
Expand All @@ -58,7 +72,7 @@ public struct Request: Sendable {
) {
self.uri = .init(head.path ?? "")
self.head = head
self.body = .init(nioAsyncChannelInbound: .init(iterator: bodyIterator))
self._body = .init(nioAsyncChannelInbound: .init(iterator: bodyIterator))
}

/// Collapse body into one ByteBuffer.
Expand Down
55 changes: 34 additions & 21 deletions Sources/HummingbirdCore/Request/RequestBody+inboundClose.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,35 @@ extension RequestBody {
_ operation: (RequestBody) async throws -> Value,
onInboundClosed: @Sendable @escaping () -> Void
) async throws -> Value {
let iterator: UnsafeTransfer<NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator> =
switch self._backing {
case .nioAsyncChannelRequestBody(let iterator):
iterator.underlyingIterator
default:
preconditionFailure("Cannot run consumeWithInboundCloseHandler on edited request body")
}
let (requestBody, source) = RequestBody.makeStream()
return try await withInboundCloseHandler(
iterator: iterator.wrappedValue,
source: source,
operation: {
try await operation(requestBody)
},
onInboundClosed: onInboundClosed
)
switch self._backing {
case .nioAsyncChannelRequestBody(let body):
return try await withInboundCloseHandler(
iterator: body.underlyingIterator.wrappedValue,
source: source,
operation: {
try await operation(requestBody)
},
onInboundClosed: onInboundClosed
)

case .byteBuffer(_, .some(let originalRequestBody)), .anyAsyncSequence(_, .some(let originalRequestBody)):
let iterator =
self
.mergeWithUnderlyingRequestPartIterator(originalRequestBody.underlyingIterator.wrappedValue)
.makeAsyncIterator()
return try await withInboundCloseHandler(
iterator: iterator,
source: source,
operation: {
try await operation(requestBody)
},
onInboundClosed: onInboundClosed
)

default:
preconditionFailure("Cannot run consumeWithInboundCloseHandler on edited request body")
}
}

/// Run provided closure but cancel it if the inbound request part stream is closed.
Expand Down Expand Up @@ -94,13 +107,13 @@ extension RequestBody {
}
}

fileprivate func withInboundCloseHandler<Value: Sendable>(
fileprivate func withInboundCloseHandler<Value: Sendable, AsyncIterator: AsyncIteratorProtocol>(
isolation: isolated (any Actor)? = #isolation,
iterator: NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator,
iterator: AsyncIterator,
source: RequestBody.Source,
operation: () async throws -> Value,
onInboundClosed: @Sendable @escaping () -> Void
) async throws -> Value {
) async throws -> Value where AsyncIterator.Element == HTTPRequestPart {
let unsafeIterator = UnsafeTransfer(iterator)
let unsafeOnInboundClosed = UnsafeTransfer(onInboundClosed)
let value = try await withThrowingTaskGroup(of: Void.self) { group in
Expand All @@ -123,10 +136,10 @@ extension RequestBody {
case nextRequestReady
}

fileprivate func iterate(
iterator: NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator,
fileprivate func iterate<AsyncIterator: AsyncIteratorProtocol>(
iterator: AsyncIterator,
source: RequestBody.Source
) async throws -> IterateResult {
) async throws -> IterateResult where AsyncIterator.Element == HTTPRequestPart {
var iterator = iterator
while let part = try await iterator.next() {
switch part {
Expand Down
22 changes: 15 additions & 7 deletions Sources/HummingbirdCore/Request/RequestBody.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ import NIOHTTPTypes
public struct RequestBody: Sendable, AsyncSequence {
@usableFromInline
internal enum _Backing: Sendable {
case byteBuffer(ByteBuffer)
case byteBuffer(ByteBuffer, NIOAsyncChannelRequestBody?)
case nioAsyncChannelRequestBody(NIOAsyncChannelRequestBody)
case anyAsyncSequence(AnyAsyncSequence<ByteBuffer>)
case anyAsyncSequence(AnyAsyncSequence<ByteBuffer>, NIOAsyncChannelRequestBody?)
}

@usableFromInline
Expand All @@ -40,7 +40,7 @@ public struct RequestBody: Sendable, AsyncSequence {
/// - Parameter buffer: ByteBuffer
@inlinable
public init(buffer: ByteBuffer) {
self.init(.byteBuffer(buffer))
self.init(.byteBuffer(buffer, nil))
}

/// Initialise ``RequestBody`` from AsyncSequence of ByteBuffers
Expand All @@ -54,7 +54,7 @@ public struct RequestBody: Sendable, AsyncSequence {
/// - Parameter asyncSequence: AsyncSequence
@inlinable
public init<AS: AsyncSequence & Sendable>(asyncSequence: AS) where AS.Element == ByteBuffer {
self.init(.anyAsyncSequence(.init(asyncSequence)))
self.init(.anyAsyncSequence(.init(asyncSequence), nil))
}
}

Expand Down Expand Up @@ -103,14 +103,22 @@ extension RequestBody {
@inlinable
public func makeAsyncIterator() -> AsyncIterator {
switch self._backing {
case .byteBuffer(let buffer):
case .byteBuffer(let buffer, _):
return .init(.byteBuffer(buffer))
case .nioAsyncChannelRequestBody(let requestBody):
return .init(.nioAsyncChannelRequestBody(requestBody.makeAsyncIterator()))
case .anyAsyncSequence(let stream):
case .anyAsyncSequence(let stream, _):
return .init(.anyAsyncSequence(stream.makeAsyncIterator()))
}
}

var originalRequestBody: NIOAsyncChannelRequestBody? {
switch _backing {
case .nioAsyncChannelRequestBody(let body): body
case .byteBuffer(_, let body): body
case .anyAsyncSequence: nil
}
}
}

/// Extend RequestBody to create request body streams backed by `NIOThrowingAsyncSequenceProducer`.
Expand Down Expand Up @@ -235,7 +243,7 @@ package struct NIOAsyncChannelRequestBody: Sendable, AsyncSequence {
public typealias InboundStream = NIOAsyncChannelInboundStream<HTTPRequestPart>

@usableFromInline
internal let underlyingIterator: UnsafeTransfer<NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator>
internal let underlyingIterator: UnsafeTransfer<InboundStream.AsyncIterator>
@usableFromInline
internal let alreadyIterated: NIOLockedValueBox<Bool>

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2023-2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import NIOCore
import NIOHTTPTypes

/// AsyncSequence used by consumeWithInboundCloseHandler
///
/// It will provide the buffers output by the ResponseBody and when that finishes will start
/// iterating what is left of the underlying request part stream, and continue iterating until
/// it hits the next head
struct RequestBodyMergedWithUnderlyingRequestPartIterator<Base: AsyncSequence>: AsyncSequence where Base.Element == ByteBuffer {
typealias Element = HTTPRequestPart
let base: Base
let underlyingIterator: NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator

struct AsyncIterator: AsyncIteratorProtocol {
enum CurrentAsyncIterator {
case base(Base.AsyncIterator, underlying: NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator)
case underlying(NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator)
case done
}
var current: CurrentAsyncIterator

init(iterator: Base.AsyncIterator, underlying: NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator) {
self.current = .base(iterator, underlying: underlying)
}

mutating func next() async throws -> HTTPRequestPart? {
switch self.current {
case .base(var base, let underlying):
if let element = try await base.next() {
self.current = .base(base, underlying: underlying)
return .body(element)
} else {
self.current = .underlying(underlying)
return .end(nil)
}

case .underlying(var underlying):
while true {
let part = try await underlying.next()
if case .head = part {
self.current = .done
return part
}
}
self.current = .underlying(underlying)
return nil

case .done:
return nil
}
}
}

func makeAsyncIterator() -> AsyncIterator {
.init(iterator: base.makeAsyncIterator(), underlying: underlyingIterator)
}
}

extension RequestBody {
func mergeWithUnderlyingRequestPartIterator(
_ iterator: NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator
) -> RequestBodyMergedWithUnderlyingRequestPartIterator<Self> {
.init(base: self, underlyingIterator: iterator)
}
}
66 changes: 66 additions & 0 deletions Tests/HummingbirdTests/ApplicationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,72 @@ final class ApplicationTests: XCTestCase {
}
}
}

/// Test consumeWithInboundHandler after having collected the Request body
@available(macOS 15, iOS 18, tvOS 18, *)
func testConsumeWithInboundHandlerAfterCollect() async throws {
let router = Router()
router.post("streaming") { request, context -> Response in
var request = request
_ = try await request.collectBody(upTo: .max)
let request2 = request
return Response(
status: .ok,
body: .init { writer in
try await request2.body.consumeWithInboundCloseHandler { body in
try await writer.write(body)
} onInboundClosed: {
}
try await writer.finish(nil)
}
)
}
let app = Application(responder: router.buildResponder())

try await app.test(.live) { client in
let buffer = Self.randomBuffer(size: 640_001)
try await client.execute(uri: "/streaming", method: .post, body: buffer) { response in
XCTAssertEqual(response.status, .ok)
XCTAssertEqual(response.body, buffer)
}
}
}

/// Test consumeWithInboundHandler after having replaced Request.body with a new streamed RequestBody
@available(macOS 15, iOS 18, tvOS 18, *)
func testConsumeWithInboundHandlerAfterReplacingBody() async throws {
let router = Router()
router.post("streaming") { request, context -> Response in
var request = request
request.body = .init(
asyncSequence: request.body.map {
let view = $0.readableBytesView.map { $0 ^ 255 }
return ByteBuffer(bytes: view)
}
)
let request2 = request
return Response(
status: .ok,
body: .init { writer in
try await request2.body.consumeWithInboundCloseHandler { body in
try await writer.write(body)
} onInboundClosed: {
}
try await writer.finish(nil)
}
)
}
let app = Application(responder: router.buildResponder())

try await app.test(.live) { client in
let buffer = Self.randomBuffer(size: 640_001)
let xorBuffer = ByteBuffer(bytes: buffer.readableBytesView.map { $0 ^ 255 })
try await client.execute(uri: "/streaming", method: .post, body: buffer) { response in
XCTAssertEqual(response.status, .ok)
XCTAssertEqual(response.body, xorBuffer)
}
}
}
#endif
}

Expand Down

0 comments on commit 294b620

Please sign in to comment.