diff --git a/Sources/HummingbirdCore/Request/Request+inboundClose.swift b/Sources/HummingbirdCore/Request/Request+inboundClose.swift new file mode 100644 index 00000000..318962d4 --- /dev/null +++ b/Sources/HummingbirdCore/Request/Request+inboundClose.swift @@ -0,0 +1,201 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import HTTPTypes +import NIOConcurrencyHelpers +import NIOCore +import NIOHTTPTypes + +#if compiler(>=6.0) +@available(macOS 15, iOS 18, tvOS 18, *) +extension Request { + /// Run provided closure but cancel it if the inbound request part stream is closed. + /// + /// For `cancelOnInboundClose` to work you need to enable it in the HTTP channel configuration + /// using ``HTTP1Channel/Configuration/supportCancelOnInboundClosure``. + /// + /// - Parameter process: closure to run + /// - Returns: Return value of closure + public func cancelOnInboundClose(_ operation: sending @escaping (Request) async throws -> Value) async throws -> Value { + guard let iterationState = self.iterationState else { return try await operation(self) } + let iterator: UnsafeTransfer.AsyncIterator>? = + switch self.body._backing { + case .nioAsyncChannelRequestBody(let iterator): + iterator.underlyingIterator + default: + nil + } + let (stream, source) = RequestBody.makeStream() + var request = self + request.body = stream + let newRequest = request + return try await iterationState.cancelOnIteratorFinished(iterator: iterator?.wrappedValue, source: source) { + try await operation(newRequest) + } + } + + /// Run provided closure but cancel it if the inbound request part stream is closed. + /// + /// For `withInboundCloseHandler` to work you need to enable it in the HTTP channel configuration + /// using ``HTTP1Channel/Configuration/supportCancelOnInboundClosure``. + /// + /// - Parameter process: closure to run + /// - Returns: Return value of closure + public func withInboundCloseHandler( + _ operation: sending @escaping (Request) async throws -> Value, + onInboundClosed: @Sendable @escaping () -> Void + ) async throws -> Value { + guard let iterationState = self.iterationState else { return try await operation(self) } + let iterator: UnsafeTransfer.AsyncIterator>? = + switch self.body._backing { + case .nioAsyncChannelRequestBody(let iterator): + iterator.underlyingIterator + default: + nil + } + let (stream, source) = RequestBody.makeStream() + var request = self + request.body = stream + let newRequest = request + return try await iterationState.withInboundCloseHandler( + iterator: iterator?.wrappedValue, + source: source, + operation: { + try await operation(newRequest) + }, + onInboundClosed: onInboundClosed + ) + } +} +#endif // compiler(>=6.0) + +/// Request iteration state +@usableFromInline +package actor RequestIterationState: Sendable { + fileprivate enum CancelOnInboundGroupType { + case value(Value) + case done + } + @usableFromInline + package enum State: Sendable { + case idle + case processing + case nextHead(HTTPRequest) + case closed + } + @usableFromInline + var state: State + + init() { + self.state = .idle + } + + enum IterateResult { + case inboundClosed + case nextRequestReady + } + + #if compiler(>=6.0) + @available(macOS 15, iOS 18, tvOS 18, *) + func iterate( + iterator: NIOAsyncChannelInboundStream.AsyncIterator, + source: RequestBody.Source + ) async throws -> IterateResult { + var iterator = iterator + while let part = try await iterator.next(isolation: self) { + switch part { + case .head(let head): + self.state = .nextHead(head) + return .nextRequestReady + case .body(let buffer): + try await source.yield(buffer) + case .end: + source.finish() + } + } + return .inboundClosed + } + + @available(macOS 15, iOS 18, tvOS 18, *) + func cancelOnIteratorFinished( + iterator: NIOAsyncChannelInboundStream.AsyncIterator?, + source: RequestBody.Source, + operation: sending @escaping () async throws -> Value + ) async throws -> Value { + switch (self.state, iterator) { + case (.idle, .some(let asyncIterator)): + self.state = .processing + let unsafeIterator = UnsafeTransfer(asyncIterator) + let unsafeOperation = UnsafeTransfer(operation) + return try await withThrowingTaskGroup(of: CancelOnInboundGroupType.self) { group in + group.addTask { + guard try await self.iterate(iterator: unsafeIterator.wrappedValue, source: source) == .nextRequestReady else { + throw CancellationError() + } + return .done + } + group.addTask { + try await .value(unsafeOperation.wrappedValue()) + } + do { + while let result = try await group.next() { + if case .value(let value) = result { + return value + } + } + } catch { + self.state = .closed + throw error + } + preconditionFailure("Cannot reach here") + } + case (.idle, .none), (.processing, _), (.nextHead, _): + return try await operation() + + case (.closed, _): + throw CancellationError() + } + } + + @available(macOS 15, iOS 18, tvOS 18, *) + func withInboundCloseHandler( + iterator: NIOAsyncChannelInboundStream.AsyncIterator?, + source: RequestBody.Source, + operation: sending @escaping () async throws -> Value, + onInboundClosed: @Sendable @escaping () -> Void + ) async throws -> Value { + switch (self.state, iterator) { + case (.idle, .some(let asyncIterator)): + self.state = .processing + let unsafeIterator = UnsafeTransfer(asyncIterator) + let unsafeOnInboundClosed = UnsafeTransfer(onInboundClosed) + return try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + if try await self.iterate(iterator: unsafeIterator.wrappedValue, source: source) == .inboundClosed { + unsafeOnInboundClosed.wrappedValue() + } + } + let value = try await operation() + group.cancelAll() + return value + } + case (.idle, .none), (.processing, _), (.nextHead, _): + return try await operation() + + case (.closed, _): + throw CancellationError() + } + } + #endif // compiler(>=6.0) +} diff --git a/Sources/HummingbirdCore/Request/Request.swift b/Sources/HummingbirdCore/Request/Request.swift index 0cb77f94..253dfd99 100644 --- a/Sources/HummingbirdCore/Request/Request.swift +++ b/Sources/HummingbirdCore/Request/Request.swift @@ -2,7 +2,7 @@ // // This source file is part of the Hummingbird server framework project // -// Copyright (c) 2021-2022 the Hummingbird authors +// Copyright (c) 2021-2024 the Hummingbird authors // Licensed under Apache License v2.0 // // See LICENSE.txt for license information @@ -13,7 +13,9 @@ //===----------------------------------------------------------------------===// import HTTPTypes +import NIOConcurrencyHelpers import NIOCore +import NIOHTTPTypes /// Holds all the values required to process a request public struct Request: Sendable { @@ -32,6 +34,9 @@ public struct Request: Sendable { @inlinable public var headers: HTTPFields { self.head.headerFields } + @usableFromInline + let iterationState: RequestIterationState? + // MARK: Initialization /// Create new Request @@ -45,6 +50,22 @@ public struct Request: Sendable { self.uri = .init(head.path ?? "") self.head = head self.body = body + self.iterationState = nil + } + + /// Create new Request + /// - Parameters: + /// - head: HTTP head + /// - bodyIterator: HTTP request part stream + package init( + head: HTTPRequest, + bodyIterator: NIOAsyncChannelInboundStream.AsyncIterator, + supportCancelOnInboundClosure: Bool + ) { + self.uri = .init(head.path ?? "") + self.head = head + self.body = .init(nioAsyncChannelInbound: .init(iterator: bodyIterator)) + self.iterationState = supportCancelOnInboundClosure ? .init() : nil } /// Collapse body into one ByteBuffer. @@ -60,6 +81,11 @@ public struct Request: Sendable { self.body = .init(buffer: byteBuffer) return byteBuffer } + + @inlinable + package func getState() async -> RequestIterationState.State? { + await self.iterationState?.state + } } extension Request: CustomStringConvertible { diff --git a/Sources/HummingbirdCore/Request/RequestBody.swift b/Sources/HummingbirdCore/Request/RequestBody.swift index 5a0600b9..c49bf13e 100644 --- a/Sources/HummingbirdCore/Request/RequestBody.swift +++ b/Sources/HummingbirdCore/Request/RequestBody.swift @@ -125,7 +125,7 @@ extension RequestBody { /// Delegate for NIOThrowingAsyncSequenceProducer @usableFromInline - final class Delegate: NIOAsyncSequenceProducerDelegate { + final class Delegate: NIOAsyncSequenceProducerDelegate, Sendable { let checkedContinuations: NIOLockedValueBox>> @usableFromInline @@ -162,13 +162,13 @@ extension RequestBody { } /// A source used for driving a ``RequestBody`` stream. - public final class Source { + public final class Source: Sendable { @usableFromInline let source: Producer.Source @usableFromInline let delegate: Delegate @usableFromInline - var waitForProduceMore: Bool + let waitForProduceMore: NIOLockedValueBox @usableFromInline init(source: Producer.Source, delegate: Delegate) { @@ -187,13 +187,13 @@ extension RequestBody { public func yield(_ element: ByteBuffer) async throws { // if previous call indicated we should stop producing wait until the delegate // says we can start producing again - if self.waitForProduceMore { + if self.waitForProduceMore.withLockedValue({ $0 }) { await self.delegate.waitForProduceMore() - self.waitForProduceMore = false + self.waitForProduceMore.withLockedValue { $0 = false } } let result = self.source.yield(element) if result == .stopProducing { - self.waitForProduceMore = true + self.waitForProduceMore.withLockedValue { $0 = true } } } diff --git a/Sources/HummingbirdCore/Response/ResponseWriter.swift b/Sources/HummingbirdCore/Response/ResponseWriter.swift index 1ee6a174..5f913c62 100644 --- a/Sources/HummingbirdCore/Response/ResponseWriter.swift +++ b/Sources/HummingbirdCore/Response/ResponseWriter.swift @@ -17,7 +17,7 @@ import NIOCore import NIOHTTPTypes /// ResponseWriter that writes directly to AsyncChannel -public struct ResponseWriter: ~Copyable { +public struct ResponseWriter: ~Copyable, Sendable { @usableFromInline let outbound: NIOAsyncChannelOutboundWriter diff --git a/Sources/HummingbirdCore/Server/HTTP/HTTP1Channel.swift b/Sources/HummingbirdCore/Server/HTTP/HTTP1Channel.swift index 4cd666f6..9988677d 100644 --- a/Sources/HummingbirdCore/Server/HTTP/HTTP1Channel.swift +++ b/Sources/HummingbirdCore/Server/HTTP/HTTP1Channel.swift @@ -28,18 +28,23 @@ public struct HTTP1Channel: ServerChildChannel, HTTPChannelHandler { public var additionalChannelHandlers: @Sendable () -> [any RemovableChannelHandler] /// Time before closing an idle channel. public var idleTimeout: TimeAmount? + /// Support being able to use ``Request/cancelOnInboundClosure`` + public var supportCancelOnInboundClosure: Bool /// Initialize HTTP1Channel.Configuration /// - Parameters: /// - additionalChannelHandlers: Additional channel handlers to add to channel pipeline after HTTP part decoding and /// before HTTP request processing /// - idleTimeout: Time before closing an idle channel + /// - supportCancelOnInboundClosure: Support being able to use ``Request/cancelOnInboundClosure`` public init( additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [], - idleTimeout: TimeAmount? = nil + idleTimeout: TimeAmount? = nil, + supportCancelOnInboundClosure: Bool = false ) { self.additionalChannelHandlers = additionalChannelHandlers self.idleTimeout = idleTimeout + self.supportCancelOnInboundClosure = supportCancelOnInboundClosure } } @@ -106,6 +111,8 @@ public struct HTTP1Channel: ServerChildChannel, HTTPChannelHandler { await handleHTTP(asyncChannel: asyncChannel, logger: logger) } + public var supportCancelOnInboundClosure: Bool { configuration.supportCancelOnInboundClosure } + public let responder: HTTPChannelHandler.Responder public let configuration: Configuration } diff --git a/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift b/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift index e61c5149..641a4604 100644 --- a/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift +++ b/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift @@ -22,7 +22,10 @@ import ServiceLifecycle /// Protocol for HTTP channels public protocol HTTPChannelHandler: ServerChildChannel { typealias Responder = @Sendable (Request, consuming ResponseWriter, Channel) async throws -> Void + /// HTTP Request responder var responder: Responder { get } + /// Support being able to use ``Request/cancelOnInboundClosure`` + var supportCancelOnInboundClosure: Bool { get } } /// Internal error thrown when an unexpected HTTP part is received eg we didn't receive @@ -45,9 +48,12 @@ extension HTTPChannelHandler { throw HTTPChannelError.unexpectedHTTPPart(part) } - while true { - let bodyStream = NIOAsyncChannelRequestBody(iterator: iterator) - let request = Request(head: head, body: .init(nioAsyncChannelInbound: bodyStream)) + readParts: while true { + let request = Request( + head: head, + bodyIterator: iterator, + supportCancelOnInboundClosure: self.supportCancelOnInboundClosure + ) let responseWriter = ResponseWriter(outbound: outbound) do { try await self.responder(request, responseWriter, asyncChannel.channel) @@ -57,7 +63,15 @@ extension HTTPChannelHandler { if request.headers[.connection] == "close" { return } - + switch await request.getState() { + case .nextHead(let newHead): + head = newHead + continue + case .closed: + break readParts + default: + break + } // Flush current request // read until we don't have a body part var part: HTTPRequestPart? @@ -88,4 +102,6 @@ extension HTTPChannelHandler { logger.trace("Failed to read/write to Channel. Error: \(error)") } } + + public var supportCancelOnInboundClosure: Bool { false } } diff --git a/Sources/HummingbirdHTTP2/HTTP2StreamChannel.swift b/Sources/HummingbirdHTTP2/HTTP2StreamChannel.swift index 3d266c4f..8188c118 100644 --- a/Sources/HummingbirdHTTP2/HTTP2StreamChannel.swift +++ b/Sources/HummingbirdHTTP2/HTTP2StreamChannel.swift @@ -71,8 +71,11 @@ struct HTTP2StreamChannel: ServerChildChannel { guard case .head(let head) = part else { throw HTTPChannelError.unexpectedHTTPPart(part) } - let bodyStream = NIOAsyncChannelRequestBody(iterator: iterator) - let request = Request(head: head, body: .init(nioAsyncChannelInbound: bodyStream)) + let request = Request( + head: head, + bodyIterator: iterator, + supportCancelOnInboundClosure: self.configuration.supportCancelOnInboundClosure + ) let responseWriter = ResponseWriter(outbound: outbound) try await self.responder(request, responseWriter, asyncChannel.channel) } diff --git a/Sources/HummingbirdTesting/TestClient.swift b/Sources/HummingbirdTesting/TestClient.swift index faba6810..b1f66042 100644 --- a/Sources/HummingbirdTesting/TestClient.swift +++ b/Sources/HummingbirdTesting/TestClient.swift @@ -104,7 +104,7 @@ public struct TestClient: Sendable { /// shutdown client public func shutdown() async throws { do { - try await self.close() + try await self.close(mode: .all) } catch TestClient.Error.connectionNotOpen { } catch ChannelError.alreadyClosed {} if case .createNew = self.eventLoopGroupProvider { @@ -163,10 +163,18 @@ public struct TestClient: Sendable { return response } - public func close() async throws { + /// Execute request to server but don't wait for response. + public func executeAndDontWaitForResponse(_ request: TestClient.Request) async throws { + let channel = try await getChannel() + let promise = self.eventLoopGroup.any().makePromise(of: TestClient.Response.self) + let task = HTTPTask(request: self.cleanupRequest(request), responsePromise: promise) + try await channel.writeAndFlush(task) + } + + public func close(mode: CloseMode = .output) async throws { self.channelPromise.completeWith(.failure(TestClient.Error.connectionNotOpen)) let channel = try await getChannel() - return try await channel.close() + return try await channel.close(mode: mode) } public func getChannel() async throws -> Channel { diff --git a/Tests/HummingbirdCoreTests/CoreTests.swift b/Tests/HummingbirdCoreTests/CoreTests.swift index 61c5c3d3..46f438d8 100644 --- a/Tests/HummingbirdCoreTests/CoreTests.swift +++ b/Tests/HummingbirdCoreTests/CoreTests.swift @@ -553,6 +553,105 @@ final class HummingBirdCoreTests: XCTestCase { await serviceGroup.triggerGracefulShutdown() } } + + #if compiler(>=6.0) + /// Test running cancel on inbound close without an inbound close + @available(macOS 15, iOS 18, tvOS 18, *) + func testCancelOnCloseInboundWithoutClose() async throws { + try await testServer( + responder: { (request, responseWriter: consuming ResponseWriter, _) in + var bodyWriter = try await responseWriter.writeHead(.init(status: .ok)) + try await request.cancelOnInboundClose { request in + try await bodyWriter.write(request.body) + try await bodyWriter.finish(nil) + } + }, + httpChannelSetup: .http1(configuration: .init(supportCancelOnInboundClosure: true)), + configuration: .init(address: .hostname(port: 0)), + eventLoopGroup: Self.eventLoopGroup, + logger: Logger(label: "Hummingbird") + ) { client in + let response = try await client.get("/") + XCTAssertNil(response.body) + let response2 = try await client.post("/", body: ByteBuffer(string: "Hello")) + let body2 = try XCTUnwrap(response2.body) + XCTAssertEqual(String(buffer: body2), "Hello") + } + } + + /// Test running cancel on inbound close actually cancels on inbound closure + @available(macOS 15, iOS 18, tvOS 18, *) + func testCancelOnCloseInbound() async throws { + let handlerPromise = Promise() + try await testServer( + responder: { (request, responseWriter: consuming ResponseWriter, _) in + await handlerPromise.complete(()) + let bodyWriter = try await responseWriter.writeHead(.init(status: .ok)) + try await request.cancelOnInboundClose { request in + var bodyWriter2 = bodyWriter + let body = try await request.body.collect(upTo: .max) + for _ in 0..<200 { + do { + try Task.checkCancellation() + try await Task.sleep(for: .seconds(1)) + try await bodyWriter2.write(body) + } catch { + throw error + } + } + try await Task.sleep(for: .seconds(60)) + } + }, + httpChannelSetup: .http1(configuration: .init(supportCancelOnInboundClosure: true)), + configuration: .init(address: .hostname(port: 0)), + eventLoopGroup: Self.eventLoopGroup, + logger: Logger(label: "Hummingbird") + ) { client in + try await client.executeAndDontWaitForResponse(.init("/", method: .get)) + await handlerPromise.wait() + try await client.close() + } + } + + /// Test running withInboundCloseHandler + @available(macOS 15, iOS 18, tvOS 18, *) + func testWithCloseInboundHandler() async throws { + let handlerPromise = Promise() + try await testServer( + responder: { (request, responseWriter: consuming ResponseWriter, _) in + await handlerPromise.complete(()) + let bodyWriter = try await responseWriter.writeHead(.init(status: .ok)) + let finished = ManagedAtomic(false) + try await request.withInboundCloseHandler { request in + var bodyWriter2 = bodyWriter + let body = try await request.body.collect(upTo: .max) + for _ in 0..<200 { + do { + if finished.load(ordering: .relaxed) { + break + } + try await Task.sleep(for: .milliseconds(300)) + try await bodyWriter2.write(body) + } catch { + throw error + } + } + try await bodyWriter2.finish(nil) + } onInboundClosed: { + finished.store(true, ordering: .relaxed) + } + }, + httpChannelSetup: .http1(configuration: .init(supportCancelOnInboundClosure: true)), + configuration: .init(address: .hostname(port: 0)), + eventLoopGroup: Self.eventLoopGroup, + logger: Logger(label: "Hummingbird") + ) { client in + try await client.executeAndDontWaitForResponse(.init("/", method: .get)) + await handlerPromise.wait() + try await client.close() + } + } + #endif // compiler(>=6.0) } struct DelayAsyncSequence: AsyncSequence { diff --git a/Tests/HummingbirdCoreTests/TestUtils.swift b/Tests/HummingbirdCoreTests/TestUtils.swift index 57add19a..f50e52ac 100644 --- a/Tests/HummingbirdCoreTests/TestUtils.swift +++ b/Tests/HummingbirdCoreTests/TestUtils.swift @@ -80,22 +80,38 @@ public func testServer( clientConfiguration: TestClient.Configuration = .init(), test: @escaping @Sendable (TestClient) async throws -> Value ) async throws -> Value { - try await testServer( - responder: responder, - httpChannelSetup: httpChannelSetup, - configuration: configuration, - eventLoopGroup: eventLoopGroup, - logger: logger - ) { port in + try await withThrowingTaskGroup(of: Void.self) { group in + let promise = Promise() + let server = try httpChannelSetup.buildServer( + configuration: configuration, + eventLoopGroup: eventLoopGroup, + logger: logger, + responder: responder, + onServerRunning: { await promise.complete($0.localAddress!.port!) } + ) + let serviceGroup = ServiceGroup( + configuration: .init( + services: [server], + gracefulShutdownSignals: [.sigterm, .sigint], + logger: logger + ) + ) + + group.addTask { + try await serviceGroup.run() + } + let port = await promise.wait() let client = TestClient( host: "localhost", port: port, configuration: clientConfiguration, eventLoopGroupProvider: .createNew ) + print("Client connecting to port \(port)") client.connect() let value = try await test(client) - try await client.shutdown() + try? await client.shutdown() + await serviceGroup.triggerGracefulShutdown() return value } }