diff --git a/Sources/HummingbirdCore/Request/Request.swift b/Sources/HummingbirdCore/Request/Request.swift index 0cb77f94..9b27df89 100644 --- a/Sources/HummingbirdCore/Request/Request.swift +++ b/Sources/HummingbirdCore/Request/Request.swift @@ -2,7 +2,7 @@ // // This source file is part of the Hummingbird server framework project // -// Copyright (c) 2021-2022 the Hummingbird authors +// Copyright (c) 2021-2024 the Hummingbird authors // Licensed under Apache License v2.0 // // See LICENSE.txt for license information @@ -14,6 +14,7 @@ import HTTPTypes import NIOCore +import NIOHTTPTypes /// Holds all the values required to process a request public struct Request: Sendable { @@ -47,6 +48,19 @@ public struct Request: Sendable { self.body = body } + /// Create new Request + /// - Parameters: + /// - head: HTTP head + /// - bodyIterator: HTTP request part stream + package init( + head: HTTPRequest, + bodyIterator: NIOAsyncChannelInboundStream.AsyncIterator + ) { + self.uri = .init(head.path ?? "") + self.head = head + self.body = .init(nioAsyncChannelInbound: .init(iterator: bodyIterator)) + } + /// Collapse body into one ByteBuffer. /// /// This will store the collated ByteBuffer back into the request so is a mutating method. If diff --git a/Sources/HummingbirdCore/Request/RequestBody+inboundClose.swift b/Sources/HummingbirdCore/Request/RequestBody+inboundClose.swift new file mode 100644 index 00000000..dab3220e --- /dev/null +++ b/Sources/HummingbirdCore/Request/RequestBody+inboundClose.swift @@ -0,0 +1,144 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import HTTPTypes +import NIOConcurrencyHelpers +import NIOCore +import NIOHTTPTypes + +#if compiler(>=6.0) +extension RequestBody { + /// Run provided closure but cancel it if the inbound request part stream is closed. + /// + /// This function is designed for use with long running requests like server sent events. It assumes you + /// are not going to be using the request body after calling as it consumes the request body, it also assumes + /// you havent edited the request body prior to calling this function. + /// + /// If the response finishes the connection will be closed. + /// + /// - Parameters + /// - isolation: The isolation of the method. Defaults to the isolation of the caller. + /// - operation: The actual operation + /// = onInboundClose: handler invoked when inbound is closed + /// - Returns: Return value of operation + public func consumeWithInboundCloseHandler( + isolation: isolated (any Actor)? = #isolation, + _ operation: (RequestBody) async throws -> Value, + onInboundClosed: @Sendable @escaping () -> Void + ) async throws -> Value { + let iterator: UnsafeTransfer.AsyncIterator> = + switch self._backing { + case .nioAsyncChannelRequestBody(let iterator): + iterator.underlyingIterator + default: + preconditionFailure("Cannot run consumeWithInboundCloseHandler on edited request body") + } + let (requestBody, source) = RequestBody.makeStream() + return try await withInboundCloseHandler( + iterator: iterator.wrappedValue, + source: source, + operation: { + try await operation(requestBody) + }, + onInboundClosed: onInboundClosed + ) + } + + /// Run provided closure but cancel it if the inbound request part stream is closed. + /// + /// This function is designed for use with long running requests like server sent events. It assumes you + /// are not going to be using the request body after calling as it consumes the request body, it also assumes + /// you havent edited the request body prior to calling this function. + /// + /// If the response finishes the connection will be closed. + /// + /// - Parameters + /// - isolation: The isolation of the method. Defaults to the isolation of the caller. + /// - operation: The actual operation to run + /// - Returns: Return value of operation + public func consumeWithCancellationOnInboundClose( + _ operation: sending @escaping (RequestBody) async throws -> Value + ) async throws -> Value { + let (barrier, source) = AsyncStream.makeStream() + return try await consumeWithInboundCloseHandler { body in + try await withThrowingTaskGroup(of: Value.self) { group in + let unsafeOperation = UnsafeTransfer(operation) + group.addTask { + var iterator = barrier.makeAsyncIterator() + _ = await iterator.next() + throw CancellationError() + } + group.addTask { + try await unsafeOperation.wrappedValue(body) + } + if case .some(let value) = try await group.next() { + source.finish() + return value + } + group.cancelAll() + throw CancellationError() + } + } onInboundClosed: { + source.finish() + } + } + + fileprivate func withInboundCloseHandler( + isolation: isolated (any Actor)? = #isolation, + iterator: NIOAsyncChannelInboundStream.AsyncIterator, + source: RequestBody.Source, + operation: () async throws -> Value, + onInboundClosed: @Sendable @escaping () -> Void + ) async throws -> Value { + let unsafeIterator = UnsafeTransfer(iterator) + let unsafeOnInboundClosed = UnsafeTransfer(onInboundClosed) + let value = try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + do { + if try await self.iterate(iterator: unsafeIterator.wrappedValue, source: source) == .inboundClosed { + unsafeOnInboundClosed.wrappedValue() + } + } catch is CancellationError {} + } + let value = try await operation() + group.cancelAll() + return value + } + return value + } + + fileprivate enum IterateResult { + case inboundClosed + case nextRequestReady + } + + fileprivate func iterate( + iterator: NIOAsyncChannelInboundStream.AsyncIterator, + source: RequestBody.Source + ) async throws -> IterateResult { + var iterator = iterator + while let part = try await iterator.next() { + switch part { + case .head: + return .nextRequestReady + case .body(let buffer): + try await source.yield(buffer) + case .end: + source.finish() + } + } + return .inboundClosed + } +} +#endif // compiler(>=6.0) diff --git a/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift b/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift index e61c5149..7a7782ab 100644 --- a/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift +++ b/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift @@ -22,6 +22,7 @@ import ServiceLifecycle /// Protocol for HTTP channels public protocol HTTPChannelHandler: ServerChildChannel { typealias Responder = @Sendable (Request, consuming ResponseWriter, Channel) async throws -> Void + /// HTTP Request responder var responder: Responder { get } } @@ -46,8 +47,10 @@ extension HTTPChannelHandler { } while true { - let bodyStream = NIOAsyncChannelRequestBody(iterator: iterator) - let request = Request(head: head, body: .init(nioAsyncChannelInbound: bodyStream)) + let request = Request( + head: head, + bodyIterator: iterator + ) let responseWriter = ResponseWriter(outbound: outbound) do { try await self.responder(request, responseWriter, asyncChannel.channel) diff --git a/Sources/HummingbirdHTTP2/HTTP2StreamChannel.swift b/Sources/HummingbirdHTTP2/HTTP2StreamChannel.swift index 3d266c4f..4bf775f6 100644 --- a/Sources/HummingbirdHTTP2/HTTP2StreamChannel.swift +++ b/Sources/HummingbirdHTTP2/HTTP2StreamChannel.swift @@ -71,8 +71,10 @@ struct HTTP2StreamChannel: ServerChildChannel { guard case .head(let head) = part else { throw HTTPChannelError.unexpectedHTTPPart(part) } - let bodyStream = NIOAsyncChannelRequestBody(iterator: iterator) - let request = Request(head: head, body: .init(nioAsyncChannelInbound: bodyStream)) + let request = Request( + head: head, + bodyIterator: iterator + ) let responseWriter = ResponseWriter(outbound: outbound) try await self.responder(request, responseWriter, asyncChannel.channel) } diff --git a/Sources/HummingbirdTesting/TestClient.swift b/Sources/HummingbirdTesting/TestClient.swift index faba6810..b1f66042 100644 --- a/Sources/HummingbirdTesting/TestClient.swift +++ b/Sources/HummingbirdTesting/TestClient.swift @@ -104,7 +104,7 @@ public struct TestClient: Sendable { /// shutdown client public func shutdown() async throws { do { - try await self.close() + try await self.close(mode: .all) } catch TestClient.Error.connectionNotOpen { } catch ChannelError.alreadyClosed {} if case .createNew = self.eventLoopGroupProvider { @@ -163,10 +163,18 @@ public struct TestClient: Sendable { return response } - public func close() async throws { + /// Execute request to server but don't wait for response. + public func executeAndDontWaitForResponse(_ request: TestClient.Request) async throws { + let channel = try await getChannel() + let promise = self.eventLoopGroup.any().makePromise(of: TestClient.Response.self) + let task = HTTPTask(request: self.cleanupRequest(request), responsePromise: promise) + try await channel.writeAndFlush(task) + } + + public func close(mode: CloseMode = .output) async throws { self.channelPromise.completeWith(.failure(TestClient.Error.connectionNotOpen)) let channel = try await getChannel() - return try await channel.close() + return try await channel.close(mode: mode) } public func getChannel() async throws -> Channel { diff --git a/Tests/HummingbirdCoreTests/CoreTests.swift b/Tests/HummingbirdCoreTests/CoreTests.swift index 61c5c3d3..3515a4d5 100644 --- a/Tests/HummingbirdCoreTests/CoreTests.swift +++ b/Tests/HummingbirdCoreTests/CoreTests.swift @@ -553,6 +553,125 @@ final class HummingBirdCoreTests: XCTestCase { await serviceGroup.triggerGracefulShutdown() } } + + #if compiler(>=6.0) + /// Test running withInboundCloseHandler with closing input + func testWithCloseInboundHandlerWithoutClose() async throws { + try await testServer( + responder: { (request, responseWriter: consuming ResponseWriter, _) in + var bodyWriter = try await responseWriter.writeHead(.init(status: .ok)) + do { + try await request.body.consumeWithInboundCloseHandler { body in + try await bodyWriter.write(body) + } onInboundClosed: { + } + try await bodyWriter.finish(nil) + } catch { + throw error + } + }, + httpChannelSetup: .http1(), + configuration: .init(address: .hostname(port: 0)), + eventLoopGroup: Self.eventLoopGroup, + logger: Logger(label: "Hummingbird") + ) { client in + let response = try await client.post("/", body: ByteBuffer(string: "Hello")) + let body = try XCTUnwrap(response.body) + XCTAssertEqual(String(buffer: body), "Hello") + } + } + + /// Test running withInboundCloseHandler + func testWithCloseInboundHandler() async throws { + let handlerPromise = Promise() + try await testServer( + responder: { (request, responseWriter: consuming ResponseWriter, _) in + await handlerPromise.complete(()) + var bodyWriter = try await responseWriter.writeHead(.init(status: .ok)) + let finished = ManagedAtomic(false) + try await request.body.consumeWithInboundCloseHandler { body in + let body = try await body.collect(upTo: .max) + for _ in 0..<200 { + do { + if finished.load(ordering: .relaxed) { + break + } + try await Task.sleep(for: .milliseconds(300)) + try await bodyWriter.write(body) + } catch { + throw error + } + } + } onInboundClosed: { + finished.store(true, ordering: .relaxed) + } + try await bodyWriter.finish(nil) + }, + httpChannelSetup: .http1(), + configuration: .init(address: .hostname(port: 0)), + eventLoopGroup: Self.eventLoopGroup, + logger: Logger(label: "Hummingbird") + ) { client in + try await client.executeAndDontWaitForResponse(.init("/", method: .get)) + await handlerPromise.wait() + try await client.close() + } + } + + /// Test running cancel on inbound close without an inbound close + func testCancelOnCloseInboundWithoutClose() async throws { + try await testServer( + responder: { (request, responseWriter: consuming ResponseWriter, _) in + var bodyWriter = try await responseWriter.writeHead(.init(status: .ok)) + try await request.body.consumeWithCancellationOnInboundClose { body in + try await bodyWriter.write(body) + } + try await bodyWriter.finish(nil) + }, + httpChannelSetup: .http1(), + configuration: .init(address: .hostname(port: 0)), + eventLoopGroup: Self.eventLoopGroup, + logger: Logger(label: "Hummingbird") + ) { client in + let response = try await client.post("/", body: ByteBuffer(string: "Hello")) + let body = try XCTUnwrap(response.body) + XCTAssertEqual(String(buffer: body), "Hello") + } + } + + /// Test running cancel on inbound close actually cancels on inbound closure + func testCancelOnCloseInbound() async throws { + let handlerPromise = Promise() + try await testServer( + responder: { (request, responseWriter: consuming ResponseWriter, _) in + await handlerPromise.complete(()) + var bodyWriter = try await responseWriter.writeHead(.init(status: .ok)) + try await request.body.consumeWithCancellationOnInboundClose { body in + let body = try await body.collect(upTo: .max) + for _ in 0..<50 { + do { + try Task.checkCancellation() + try await Task.sleep(for: .seconds(1)) + try await bodyWriter.write(body) + } catch { + throw error + } + } + XCTFail("Should not reach here as the handler should have been cancelled") + } + try await bodyWriter.finish(nil) + }, + httpChannelSetup: .http1(), + configuration: .init(address: .hostname(port: 0)), + eventLoopGroup: Self.eventLoopGroup, + logger: Logger(label: "Hummingbird") + ) { client in + try await client.executeAndDontWaitForResponse(.init("/", method: .get)) + await handlerPromise.wait() + try await client.close() + } + } + #endif // compiler(>=6.0) } struct DelayAsyncSequence: AsyncSequence { diff --git a/Tests/HummingbirdCoreTests/TestUtils.swift b/Tests/HummingbirdCoreTests/TestUtils.swift index 57add19a..138caf2b 100644 --- a/Tests/HummingbirdCoreTests/TestUtils.swift +++ b/Tests/HummingbirdCoreTests/TestUtils.swift @@ -80,13 +80,27 @@ public func testServer( clientConfiguration: TestClient.Configuration = .init(), test: @escaping @Sendable (TestClient) async throws -> Value ) async throws -> Value { - try await testServer( - responder: responder, - httpChannelSetup: httpChannelSetup, - configuration: configuration, - eventLoopGroup: eventLoopGroup, - logger: logger - ) { port in + try await withThrowingTaskGroup(of: Void.self) { group in + let promise = Promise() + let server = try httpChannelSetup.buildServer( + configuration: configuration, + eventLoopGroup: eventLoopGroup, + logger: logger, + responder: responder, + onServerRunning: { await promise.complete($0.localAddress!.port!) } + ) + let serviceGroup = ServiceGroup( + configuration: .init( + services: [server], + gracefulShutdownSignals: [.sigterm, .sigint], + logger: logger + ) + ) + + group.addTask { + try await serviceGroup.run() + } + let port = await promise.wait() let client = TestClient( host: "localhost", port: port, @@ -95,7 +109,8 @@ public func testServer( ) client.connect() let value = try await test(client) - try await client.shutdown() + try? await client.shutdown() + await serviceGroup.triggerGracefulShutdown() return value } } diff --git a/Tests/HummingbirdTests/ApplicationTests.swift b/Tests/HummingbirdTests/ApplicationTests.swift index 3063a408..a8c26a91 100644 --- a/Tests/HummingbirdTests/ApplicationTests.swift +++ b/Tests/HummingbirdTests/ApplicationTests.swift @@ -866,6 +866,34 @@ final class ApplicationTests: XCTestCase { } } } + + #if compiler(>=6.0) + /// Test consumeWithInboundCloseHandler + func testConsumeWithInboundHandler() async throws { + let router = Router() + router.post("streaming") { request, context -> Response in + Response( + status: .ok, + body: .init { writer in + try await request.body.consumeWithInboundCloseHandler { body in + try await writer.write(body) + } onInboundClosed: { + } + try await writer.finish(nil) + } + ) + } + let app = Application(responder: router.buildResponder()) + + try await app.test(.live) { client in + let buffer = Self.randomBuffer(size: 640_001) + try await client.execute(uri: "/streaming", method: .post, body: buffer) { response in + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.body, buffer) + } + } + } + #endif } /// HTTPField used during tests