Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Storing underlying request part iterator with RequestBody #634

Merged
merged 5 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions Sources/HummingbirdCore/Request/Request.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,28 @@ public struct Request: Sendable {
/// HTTP head
public let head: HTTPRequest
/// Body of HTTP request
public var body: RequestBody
private var _body: RequestBody
/// Request HTTP method
@inlinable
public var method: HTTPRequest.Method { self.head.method }
/// Request HTTP headers
@inlinable
public var headers: HTTPFields { self.head.headerFields }

public var body: RequestBody {
get { _body }
set {
let original = _body.originalRequestBody
switch newValue._backing {
case .nioAsyncChannelRequestBody:
self._body = body
case .byteBuffer(let buffer, _):
self._body = .init(.byteBuffer(buffer, original))
case .anyAsyncSequence(let seq, _):
self._body = .init(.anyAsyncSequence(seq, original))
}
}
}
// MARK: Initialization

/// Create new Request
Expand All @@ -45,7 +59,7 @@ public struct Request: Sendable {
) {
self.uri = .init(head.path ?? "")
self.head = head
self.body = body
self._body = body
}

/// Create new Request
Expand All @@ -58,7 +72,7 @@ public struct Request: Sendable {
) {
self.uri = .init(head.path ?? "")
self.head = head
self.body = .init(nioAsyncChannelInbound: .init(iterator: bodyIterator))
self._body = .init(nioAsyncChannelInbound: .init(iterator: bodyIterator))
}

/// Collapse body into one ByteBuffer.
Expand Down
55 changes: 34 additions & 21 deletions Sources/HummingbirdCore/Request/RequestBody+inboundClose.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,35 @@ extension RequestBody {
_ operation: (RequestBody) async throws -> Value,
onInboundClosed: @Sendable @escaping () -> Void
) async throws -> Value {
let iterator: UnsafeTransfer<NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator> =
switch self._backing {
case .nioAsyncChannelRequestBody(let iterator):
iterator.underlyingIterator
default:
preconditionFailure("Cannot run consumeWithInboundCloseHandler on edited request body")
}
let (requestBody, source) = RequestBody.makeStream()
return try await withInboundCloseHandler(
iterator: iterator.wrappedValue,
source: source,
operation: {
try await operation(requestBody)
},
onInboundClosed: onInboundClosed
)
switch self._backing {
case .nioAsyncChannelRequestBody(let body):
return try await withInboundCloseHandler(
iterator: body.underlyingIterator.wrappedValue,
source: source,
operation: {
try await operation(requestBody)
},
onInboundClosed: onInboundClosed
)

case .byteBuffer(_, .some(let originalRequestBody)), .anyAsyncSequence(_, .some(let originalRequestBody)):
let iterator =
self
.mergeWithUnderlyingRequestPartIterator(originalRequestBody.underlyingIterator.wrappedValue)
.makeAsyncIterator()
return try await withInboundCloseHandler(
iterator: iterator,
source: source,
operation: {
try await operation(requestBody)
},
onInboundClosed: onInboundClosed
)

default:
preconditionFailure("Cannot run consumeWithInboundCloseHandler on edited request body")
}
}

/// Run provided closure but cancel it if the inbound request part stream is closed.
Expand Down Expand Up @@ -94,13 +107,13 @@ extension RequestBody {
}
}

fileprivate func withInboundCloseHandler<Value: Sendable>(
fileprivate func withInboundCloseHandler<Value: Sendable, AsyncIterator: AsyncIteratorProtocol>(
isolation: isolated (any Actor)? = #isolation,
iterator: NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator,
iterator: AsyncIterator,
source: RequestBody.Source,
operation: () async throws -> Value,
onInboundClosed: @Sendable @escaping () -> Void
) async throws -> Value {
) async throws -> Value where AsyncIterator.Element == HTTPRequestPart {
let unsafeIterator = UnsafeTransfer(iterator)
let unsafeOnInboundClosed = UnsafeTransfer(onInboundClosed)
let value = try await withThrowingTaskGroup(of: Void.self) { group in
Expand All @@ -123,10 +136,10 @@ extension RequestBody {
case nextRequestReady
}

fileprivate func iterate(
iterator: NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator,
fileprivate func iterate<AsyncIterator: AsyncIteratorProtocol>(
iterator: AsyncIterator,
source: RequestBody.Source
) async throws -> IterateResult {
) async throws -> IterateResult where AsyncIterator.Element == HTTPRequestPart {
var iterator = iterator
while let part = try await iterator.next() {
switch part {
Expand Down
22 changes: 15 additions & 7 deletions Sources/HummingbirdCore/Request/RequestBody.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ import NIOHTTPTypes
public struct RequestBody: Sendable, AsyncSequence {
@usableFromInline
internal enum _Backing: Sendable {
case byteBuffer(ByteBuffer)
case byteBuffer(ByteBuffer, NIOAsyncChannelRequestBody?)
case nioAsyncChannelRequestBody(NIOAsyncChannelRequestBody)
case anyAsyncSequence(AnyAsyncSequence<ByteBuffer>)
case anyAsyncSequence(AnyAsyncSequence<ByteBuffer>, NIOAsyncChannelRequestBody?)
}

@usableFromInline
Expand All @@ -40,7 +40,7 @@ public struct RequestBody: Sendable, AsyncSequence {
/// - Parameter buffer: ByteBuffer
@inlinable
public init(buffer: ByteBuffer) {
self.init(.byteBuffer(buffer))
self.init(.byteBuffer(buffer, nil))
}

/// Initialise ``RequestBody`` from AsyncSequence of ByteBuffers
Expand All @@ -54,7 +54,7 @@ public struct RequestBody: Sendable, AsyncSequence {
/// - Parameter asyncSequence: AsyncSequence
@inlinable
public init<AS: AsyncSequence & Sendable>(asyncSequence: AS) where AS.Element == ByteBuffer {
self.init(.anyAsyncSequence(.init(asyncSequence)))
self.init(.anyAsyncSequence(.init(asyncSequence), nil))
}
}

Expand Down Expand Up @@ -103,14 +103,22 @@ extension RequestBody {
@inlinable
public func makeAsyncIterator() -> AsyncIterator {
switch self._backing {
case .byteBuffer(let buffer):
case .byteBuffer(let buffer, _):
return .init(.byteBuffer(buffer))
case .nioAsyncChannelRequestBody(let requestBody):
return .init(.nioAsyncChannelRequestBody(requestBody.makeAsyncIterator()))
case .anyAsyncSequence(let stream):
case .anyAsyncSequence(let stream, _):
return .init(.anyAsyncSequence(stream.makeAsyncIterator()))
}
}

var originalRequestBody: NIOAsyncChannelRequestBody? {
switch _backing {
case .nioAsyncChannelRequestBody(let body): body
case .byteBuffer(_, let body): body
case .anyAsyncSequence: nil
}
}
}

/// Extend RequestBody to create request body streams backed by `NIOThrowingAsyncSequenceProducer`.
Expand Down Expand Up @@ -235,7 +243,7 @@ package struct NIOAsyncChannelRequestBody: Sendable, AsyncSequence {
public typealias InboundStream = NIOAsyncChannelInboundStream<HTTPRequestPart>

@usableFromInline
internal let underlyingIterator: UnsafeTransfer<NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator>
internal let underlyingIterator: UnsafeTransfer<InboundStream.AsyncIterator>
@usableFromInline
internal let alreadyIterated: NIOLockedValueBox<Bool>

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2023-2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import NIOCore
import NIOHTTPTypes

/// AsyncSequence used by consumeWithInboundCloseHandler
///
/// It will provide the buffers output by the ResponseBody and when that finishes will start
/// iterating what is left of the underlying request part stream, and continue iterating until
/// it hits the next head
struct RequestBodyMergedWithUnderlyingRequestPartIterator<Base: AsyncSequence>: AsyncSequence where Base.Element == ByteBuffer {
typealias Element = HTTPRequestPart
let base: Base
let underlyingIterator: NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator

struct AsyncIterator: AsyncIteratorProtocol {
enum CurrentAsyncIterator {
case base(Base.AsyncIterator, underlying: NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator)
case underlying(NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator)
case done
}
var current: CurrentAsyncIterator

init(iterator: Base.AsyncIterator, underlying: NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator) {
self.current = .base(iterator, underlying: underlying)
}

mutating func next() async throws -> HTTPRequestPart? {
switch self.current {
case .base(var base, let underlying):
if let element = try await base.next() {
self.current = .base(base, underlying: underlying)
return .body(element)
} else {
self.current = .underlying(underlying)
return .end(nil)
}

case .underlying(var underlying):
while true {
let part = try await underlying.next()
if case .head = part {
self.current = .done
return part
}
}
self.current = .underlying(underlying)
return nil

case .done:
return nil
}
}
}

func makeAsyncIterator() -> AsyncIterator {
.init(iterator: base.makeAsyncIterator(), underlying: underlyingIterator)
}
}

extension RequestBody {
func mergeWithUnderlyingRequestPartIterator(
_ iterator: NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator
) -> RequestBodyMergedWithUnderlyingRequestPartIterator<Self> {
.init(base: self, underlyingIterator: iterator)
}
}
66 changes: 66 additions & 0 deletions Tests/HummingbirdTests/ApplicationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,72 @@ final class ApplicationTests: XCTestCase {
}
}
}

/// Test consumeWithInboundHandler after having collected the Request body
@available(macOS 15, iOS 18, tvOS 18, *)
func testConsumeWithInboundHandlerAfterCollect() async throws {
let router = Router()
router.post("streaming") { request, context -> Response in
var request = request
_ = try await request.collectBody(upTo: .max)
let request2 = request
return Response(
status: .ok,
body: .init { writer in
try await request2.body.consumeWithInboundCloseHandler { body in
try await writer.write(body)
} onInboundClosed: {
adam-fowler marked this conversation as resolved.
Show resolved Hide resolved
}
try await writer.finish(nil)
}
)
}
let app = Application(responder: router.buildResponder())

try await app.test(.live) { client in
let buffer = Self.randomBuffer(size: 640_001)
try await client.execute(uri: "/streaming", method: .post, body: buffer) { response in
XCTAssertEqual(response.status, .ok)
XCTAssertEqual(response.body, buffer)
}
}
}

/// Test consumeWithInboundHandler after having replaced Request.body with a new streamed RequestBody
@available(macOS 15, iOS 18, tvOS 18, *)
func testConsumeWithInboundHandlerAfterReplacingBody() async throws {
let router = Router()
router.post("streaming") { request, context -> Response in
var request = request
request.body = .init(
asyncSequence: request.body.map {
let view = $0.readableBytesView.map { $0 ^ 255 }
return ByteBuffer(bytes: view)
}
)
let request2 = request
return Response(
status: .ok,
body: .init { writer in
try await request2.body.consumeWithInboundCloseHandler { body in
try await writer.write(body)
} onInboundClosed: {
adam-fowler marked this conversation as resolved.
Show resolved Hide resolved
}
try await writer.finish(nil)
}
)
}
let app = Application(responder: router.buildResponder())

try await app.test(.live) { client in
let buffer = Self.randomBuffer(size: 640_001)
let xorBuffer = ByteBuffer(bytes: buffer.readableBytesView.map { $0 ^ 255 })
try await client.execute(uri: "/streaming", method: .post, body: buffer) { response in
XCTAssertEqual(response.status, .ok)
XCTAssertEqual(response.body, xorBuffer)
}
}
}
#endif
}

Expand Down
Loading