Skip to content

Commit

Permalink
Cancel on inbound close (version 2) (#631)
Browse files Browse the repository at this point in the history
* Initial attempt at cancel on inbound close

* Add testCancelOnCloseInboundInResponseWriter

* Rename to testCancelOnCloseInbound

* Move code to separate file, only available for swift 6

* next(isolation: self)

* Use a sending closure for cancelOnInboundClose

* Add supportCancelOnInboundClosure to HTTP1Channel.Configuration

* Add withInboundCloseHandler

* Inbound close handler on RequestBody

* Don't allow consumeWithInboundCloseHandler to run on edited RequestBody

* Edit comments

* Remove unnecessary changes

* Simplify consumeWithCancelOnInboundClose

* onInboundClosed is Sendable so don't need to wrap it

* Use isolated to set isolation of consumeWithInboundCloseHandler

* Add application test

* Revert consumeWithCancellationOnInboundClose to take sending closure

* Remove extraneous print statement

* Remove unnecessary platform availability checks

* Fail if handler isnt cancelled in testCancelOnCloseInbound
  • Loading branch information
adam-fowler authored Dec 13, 2024
1 parent b8ca522 commit 7f08882
Show file tree
Hide file tree
Showing 8 changed files with 349 additions and 16 deletions.
16 changes: 15 additions & 1 deletion Sources/HummingbirdCore/Request/Request.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2021-2022 the Hummingbird authors
// Copyright (c) 2021-2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
Expand All @@ -14,6 +14,7 @@

import HTTPTypes
import NIOCore
import NIOHTTPTypes

/// Holds all the values required to process a request
public struct Request: Sendable {
Expand Down Expand Up @@ -47,6 +48,19 @@ public struct Request: Sendable {
self.body = body
}

/// Create new Request
/// - Parameters:
/// - head: HTTP head
/// - bodyIterator: HTTP request part stream
package init(
head: HTTPRequest,
bodyIterator: NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator
) {
self.uri = .init(head.path ?? "")
self.head = head
self.body = .init(nioAsyncChannelInbound: .init(iterator: bodyIterator))
}

/// Collapse body into one ByteBuffer.
///
/// This will store the collated ByteBuffer back into the request so is a mutating method. If
Expand Down
144 changes: 144 additions & 0 deletions Sources/HummingbirdCore/Request/RequestBody+inboundClose.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import HTTPTypes
import NIOConcurrencyHelpers
import NIOCore
import NIOHTTPTypes

#if compiler(>=6.0)
extension RequestBody {
/// Run provided closure but cancel it if the inbound request part stream is closed.
///
/// This function is designed for use with long running requests like server sent events. It assumes you
/// are not going to be using the request body after calling as it consumes the request body, it also assumes
/// you havent edited the request body prior to calling this function.
///
/// If the response finishes the connection will be closed.
///
/// - Parameters
/// - isolation: The isolation of the method. Defaults to the isolation of the caller.
/// - operation: The actual operation
/// = onInboundClose: handler invoked when inbound is closed
/// - Returns: Return value of operation
public func consumeWithInboundCloseHandler<Value: Sendable>(
isolation: isolated (any Actor)? = #isolation,
_ operation: (RequestBody) async throws -> Value,
onInboundClosed: @Sendable @escaping () -> Void
) async throws -> Value {
let iterator: UnsafeTransfer<NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator> =
switch self._backing {
case .nioAsyncChannelRequestBody(let iterator):
iterator.underlyingIterator
default:
preconditionFailure("Cannot run consumeWithInboundCloseHandler on edited request body")
}
let (requestBody, source) = RequestBody.makeStream()
return try await withInboundCloseHandler(
iterator: iterator.wrappedValue,
source: source,
operation: {
try await operation(requestBody)
},
onInboundClosed: onInboundClosed
)
}

/// Run provided closure but cancel it if the inbound request part stream is closed.
///
/// This function is designed for use with long running requests like server sent events. It assumes you
/// are not going to be using the request body after calling as it consumes the request body, it also assumes
/// you havent edited the request body prior to calling this function.
///
/// If the response finishes the connection will be closed.
///
/// - Parameters
/// - isolation: The isolation of the method. Defaults to the isolation of the caller.
/// - operation: The actual operation to run
/// - Returns: Return value of operation
public func consumeWithCancellationOnInboundClose<Value: Sendable>(
_ operation: sending @escaping (RequestBody) async throws -> Value
) async throws -> Value {
let (barrier, source) = AsyncStream<Void>.makeStream()
return try await consumeWithInboundCloseHandler { body in
try await withThrowingTaskGroup(of: Value.self) { group in
let unsafeOperation = UnsafeTransfer(operation)
group.addTask {
var iterator = barrier.makeAsyncIterator()
_ = await iterator.next()
throw CancellationError()
}
group.addTask {
try await unsafeOperation.wrappedValue(body)
}
if case .some(let value) = try await group.next() {
source.finish()
return value
}
group.cancelAll()
throw CancellationError()
}
} onInboundClosed: {
source.finish()
}
}

fileprivate func withInboundCloseHandler<Value: Sendable>(
isolation: isolated (any Actor)? = #isolation,
iterator: NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator,
source: RequestBody.Source,
operation: () async throws -> Value,
onInboundClosed: @Sendable @escaping () -> Void
) async throws -> Value {
let unsafeIterator = UnsafeTransfer(iterator)
let unsafeOnInboundClosed = UnsafeTransfer(onInboundClosed)
let value = try await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
do {
if try await self.iterate(iterator: unsafeIterator.wrappedValue, source: source) == .inboundClosed {
unsafeOnInboundClosed.wrappedValue()
}
} catch is CancellationError {}
}
let value = try await operation()
group.cancelAll()
return value
}
return value
}

fileprivate enum IterateResult {
case inboundClosed
case nextRequestReady
}

fileprivate func iterate(
iterator: NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator,
source: RequestBody.Source
) async throws -> IterateResult {
var iterator = iterator
while let part = try await iterator.next() {
switch part {
case .head:
return .nextRequestReady
case .body(let buffer):
try await source.yield(buffer)
case .end:
source.finish()
}
}
return .inboundClosed
}
}
#endif // compiler(>=6.0)
7 changes: 5 additions & 2 deletions Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import ServiceLifecycle
/// Protocol for HTTP channels
public protocol HTTPChannelHandler: ServerChildChannel {
typealias Responder = @Sendable (Request, consuming ResponseWriter, Channel) async throws -> Void
/// HTTP Request responder
var responder: Responder { get }
}

Expand All @@ -46,8 +47,10 @@ extension HTTPChannelHandler {
}

while true {
let bodyStream = NIOAsyncChannelRequestBody(iterator: iterator)
let request = Request(head: head, body: .init(nioAsyncChannelInbound: bodyStream))
let request = Request(
head: head,
bodyIterator: iterator
)
let responseWriter = ResponseWriter(outbound: outbound)
do {
try await self.responder(request, responseWriter, asyncChannel.channel)
Expand Down
6 changes: 4 additions & 2 deletions Sources/HummingbirdHTTP2/HTTP2StreamChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,10 @@ struct HTTP2StreamChannel: ServerChildChannel {
guard case .head(let head) = part else {
throw HTTPChannelError.unexpectedHTTPPart(part)
}
let bodyStream = NIOAsyncChannelRequestBody(iterator: iterator)
let request = Request(head: head, body: .init(nioAsyncChannelInbound: bodyStream))
let request = Request(
head: head,
bodyIterator: iterator
)
let responseWriter = ResponseWriter(outbound: outbound)
try await self.responder(request, responseWriter, asyncChannel.channel)
}
Expand Down
14 changes: 11 additions & 3 deletions Sources/HummingbirdTesting/TestClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ public struct TestClient: Sendable {
/// shutdown client
public func shutdown() async throws {
do {
try await self.close()
try await self.close(mode: .all)
} catch TestClient.Error.connectionNotOpen {
} catch ChannelError.alreadyClosed {}
if case .createNew = self.eventLoopGroupProvider {
Expand Down Expand Up @@ -163,10 +163,18 @@ public struct TestClient: Sendable {
return response
}

public func close() async throws {
/// Execute request to server but don't wait for response.
public func executeAndDontWaitForResponse(_ request: TestClient.Request) async throws {
let channel = try await getChannel()
let promise = self.eventLoopGroup.any().makePromise(of: TestClient.Response.self)
let task = HTTPTask(request: self.cleanupRequest(request), responsePromise: promise)
try await channel.writeAndFlush(task)
}

public func close(mode: CloseMode = .output) async throws {
self.channelPromise.completeWith(.failure(TestClient.Error.connectionNotOpen))
let channel = try await getChannel()
return try await channel.close()
return try await channel.close(mode: mode)
}

public func getChannel() async throws -> Channel {
Expand Down
119 changes: 119 additions & 0 deletions Tests/HummingbirdCoreTests/CoreTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,125 @@ final class HummingBirdCoreTests: XCTestCase {
await serviceGroup.triggerGracefulShutdown()
}
}

#if compiler(>=6.0)
/// Test running withInboundCloseHandler with closing input
func testWithCloseInboundHandlerWithoutClose() async throws {
try await testServer(
responder: { (request, responseWriter: consuming ResponseWriter, _) in
var bodyWriter = try await responseWriter.writeHead(.init(status: .ok))
do {
try await request.body.consumeWithInboundCloseHandler { body in
try await bodyWriter.write(body)
} onInboundClosed: {
}
try await bodyWriter.finish(nil)
} catch {
throw error
}
},
httpChannelSetup: .http1(),
configuration: .init(address: .hostname(port: 0)),
eventLoopGroup: Self.eventLoopGroup,
logger: Logger(label: "Hummingbird")
) { client in
let response = try await client.post("/", body: ByteBuffer(string: "Hello"))
let body = try XCTUnwrap(response.body)
XCTAssertEqual(String(buffer: body), "Hello")
}
}

/// Test running withInboundCloseHandler
func testWithCloseInboundHandler() async throws {
let handlerPromise = Promise<Void>()
try await testServer(
responder: { (request, responseWriter: consuming ResponseWriter, _) in
await handlerPromise.complete(())
var bodyWriter = try await responseWriter.writeHead(.init(status: .ok))
let finished = ManagedAtomic(false)
try await request.body.consumeWithInboundCloseHandler { body in
let body = try await body.collect(upTo: .max)
for _ in 0..<200 {
do {
if finished.load(ordering: .relaxed) {
break
}
try await Task.sleep(for: .milliseconds(300))
try await bodyWriter.write(body)
} catch {
throw error
}
}
} onInboundClosed: {
finished.store(true, ordering: .relaxed)
}
try await bodyWriter.finish(nil)
},
httpChannelSetup: .http1(),
configuration: .init(address: .hostname(port: 0)),
eventLoopGroup: Self.eventLoopGroup,
logger: Logger(label: "Hummingbird")
) { client in
try await client.executeAndDontWaitForResponse(.init("/", method: .get))
await handlerPromise.wait()
try await client.close()
}
}

/// Test running cancel on inbound close without an inbound close
func testCancelOnCloseInboundWithoutClose() async throws {
try await testServer(
responder: { (request, responseWriter: consuming ResponseWriter, _) in
var bodyWriter = try await responseWriter.writeHead(.init(status: .ok))
try await request.body.consumeWithCancellationOnInboundClose { body in
try await bodyWriter.write(body)
}
try await bodyWriter.finish(nil)
},
httpChannelSetup: .http1(),
configuration: .init(address: .hostname(port: 0)),
eventLoopGroup: Self.eventLoopGroup,
logger: Logger(label: "Hummingbird")
) { client in
let response = try await client.post("/", body: ByteBuffer(string: "Hello"))
let body = try XCTUnwrap(response.body)
XCTAssertEqual(String(buffer: body), "Hello")
}
}

/// Test running cancel on inbound close actually cancels on inbound closure
func testCancelOnCloseInbound() async throws {
let handlerPromise = Promise<Void>()
try await testServer(
responder: { (request, responseWriter: consuming ResponseWriter, _) in
await handlerPromise.complete(())
var bodyWriter = try await responseWriter.writeHead(.init(status: .ok))
try await request.body.consumeWithCancellationOnInboundClose { body in
let body = try await body.collect(upTo: .max)
for _ in 0..<50 {
do {
try Task.checkCancellation()
try await Task.sleep(for: .seconds(1))
try await bodyWriter.write(body)
} catch {
throw error
}
}
XCTFail("Should not reach here as the handler should have been cancelled")
}
try await bodyWriter.finish(nil)
},
httpChannelSetup: .http1(),
configuration: .init(address: .hostname(port: 0)),
eventLoopGroup: Self.eventLoopGroup,
logger: Logger(label: "Hummingbird")
) { client in
try await client.executeAndDontWaitForResponse(.init("/", method: .get))
await handlerPromise.wait()
try await client.close()
}
}
#endif // compiler(>=6.0)
}

struct DelayAsyncSequence<CoreSequence: AsyncSequence>: AsyncSequence {
Expand Down
Loading

0 comments on commit 7f08882

Please sign in to comment.