From 1b226daf0f8f0892b336faa54ff14b1d01bfde3d Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Thu, 16 May 2024 16:00:49 +0100 Subject: [PATCH 1/5] Revert to using quiescing helper --- .../Server/HTTP/HTTPChannelHandler.swift | 98 ++++++++----------- .../Server/HTTPUserEventHandler.swift | 11 ++- Sources/HummingbirdCore/Server/Server.swift | 50 ++++++++-- .../LiveTestFramework.swift | 4 +- Tests/HummingbirdCoreTests/HTTP2Tests.swift | 4 +- 5 files changed, 94 insertions(+), 73 deletions(-) diff --git a/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift b/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift index afdb5a3ba..ad9d779d9 100644 --- a/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift +++ b/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift @@ -32,75 +32,55 @@ enum HTTPChannelError: Error { case unexpectedHTTPPart(HTTPRequestPart) } -enum HTTPState: Int, Sendable { - case idle - case processing - case cancelled -} - extension HTTPChannelHandler { public func handleHTTP(asyncChannel: NIOAsyncChannel, logger: Logger) async { - let processingRequest = NIOLockedValueBox(HTTPState.idle) do { try await withTaskCancellationHandler { - try await withGracefulShutdownHandler { - try await asyncChannel.executeThenClose { inbound, outbound in - let responseWriter = HTTPServerBodyWriter(outbound: outbound) - var iterator = inbound.makeAsyncIterator() + try await asyncChannel.executeThenClose { inbound, outbound in + let responseWriter = HTTPServerBodyWriter(outbound: outbound) + var iterator = inbound.makeAsyncIterator() + + // read first part, verify it is a head + guard let part = try await iterator.next() else { return } + guard case .head(var head) = part else { + throw HTTPChannelError.unexpectedHTTPPart(part) + } - // read first part, verify it is a head - guard let part = try await iterator.next() else { return } - guard case .head(var head) = part else { - throw HTTPChannelError.unexpectedHTTPPart(part) + while true { + let bodyStream = NIOAsyncChannelRequestBody(iterator: iterator) + let request = Request(head: head, body: .init(asyncSequence: bodyStream)) + let response = await self.responder(request, asyncChannel.channel) + do { + try await outbound.write(.head(response.head)) + let tailHeaders = try await response.body.write(responseWriter) + try await outbound.write(.end(tailHeaders)) + } catch { + throw error + } + if request.headers[.connection] == "close" { + return } + // Flush current request + // read until we don't have a body part + var part: HTTPRequestPart? while true { - // set to processing unless it is cancelled then exit - guard processingRequest.exchange(.processing) == .idle else { break } - - let bodyStream = NIOAsyncChannelRequestBody(iterator: iterator) - let request = Request(head: head, body: .init(asyncSequence: bodyStream)) - let response: Response = await self.responder(request, asyncChannel.channel) - do { - try await outbound.write(.head(response.head)) - let tailHeaders = try await response.body.write(responseWriter) - try await outbound.write(.end(tailHeaders)) - } catch { - throw error - } - if request.headers[.connection] == "close" { - return - } - // set to idle unless it is cancelled then exit - guard processingRequest.exchange(.idle) == .processing else { break } - - // Flush current request - // read until we don't have a body part - var part: HTTPRequestPart? - while true { - part = try await iterator.next() - guard case .body = part else { break } - } - // if we have an end then read the next part - if case .end = part { - part = try await iterator.next() - } - - // if part is nil break out of loop - guard let part else { - break - } + part = try await iterator.next() + guard case .body = part else { break } + } + // if we have an end then read the next part + if case .end = part { + part = try await iterator.next() + } - // part should be a head, if not throw error - guard case .head(let newHead) = part else { throw HTTPChannelError.unexpectedHTTPPart(part) } - head = newHead + // if part is nil break out of loop + guard let part else { + break } - } - } onGracefulShutdown: { - // set to cancelled - if processingRequest.exchange(.cancelled) == .idle { - // only close the channel input if it is idle - asyncChannel.channel.close(mode: .input, promise: nil) + + // part should be a head, if not throw error + guard case .head(let newHead) = part else { throw HTTPChannelError.unexpectedHTTPPart(part) } + head = newHead } } } onCancel: { diff --git a/Sources/HummingbirdCore/Server/HTTPUserEventHandler.swift b/Sources/HummingbirdCore/Server/HTTPUserEventHandler.swift index 3b38c857b..d7e1fbbf6 100644 --- a/Sources/HummingbirdCore/Server/HTTPUserEventHandler.swift +++ b/Sources/HummingbirdCore/Server/HTTPUserEventHandler.swift @@ -35,7 +35,7 @@ public final class HTTPUserEventHandler: ChannelDuplexHandler, RemovableChannelH let part = unwrapOutboundIn(data) if case .end = part { self.requestsInProgress -= 1 - context.write(data, promise: promise) + context.writeAndFlush(data, promise: promise) if self.closeAfterResponseWritten { context.close(promise: nil) self.closeAfterResponseWritten = false @@ -61,6 +61,15 @@ public final class HTTPUserEventHandler: ChannelDuplexHandler, RemovableChannelH public func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { switch event { + case is ChannelShouldQuiesceEvent: + // we received a quiesce event. If we have any requests in progress we should + // wait for them to finish + if self.requestsInProgress > 0 { + self.closeAfterResponseWritten = true + } else { + context.close(promise: nil) + } + case IdleStateHandler.IdleStateEvent.read: // if we get an idle read event and we haven't completed reading the request // close the connection diff --git a/Sources/HummingbirdCore/Server/Server.swift b/Sources/HummingbirdCore/Server/Server.swift index 51dcab30c..078f66432 100644 --- a/Sources/HummingbirdCore/Server/Server.swift +++ b/Sources/HummingbirdCore/Server/Server.swift @@ -33,7 +33,10 @@ public actor Server: Service { onServerRunning: (@Sendable (Channel) async -> Void)? ) case starting - case running(asyncChannel: AsyncServerChannel) + case running( + asyncChannel: AsyncServerChannel, + quiescingHelper: ServerQuiescingHelper + ) case shuttingDown(shutdownPromise: EventLoopPromise) case shutdown @@ -96,7 +99,7 @@ public actor Server: Service { self.state = .starting do { - let asyncChannel = try await self.makeServer( + let (asyncChannel, quiescingHelper) = try await self.makeServer( childChannelSetup: childChannelSetup, configuration: configuration ) @@ -107,7 +110,7 @@ public actor Server: Service { fatalError("We should only be running once") case .starting: - self.state = .running(asyncChannel: asyncChannel) + self.state = .running(asyncChannel: asyncChannel, quiescingHelper: quiescingHelper) await withGracefulShutdownHandler { await onServerRunning?(asyncChannel.channel) @@ -138,13 +141,14 @@ public actor Server: Service { } case .shuttingDown, .shutdown: + self.logger.info("Shutting down") try await asyncChannel.channel.close() } } catch { self.state = .shutdown throw error } - self.state = .shutdown + case .starting, .running: fatalError("Run should only be called once") @@ -162,10 +166,20 @@ public actor Server: Service { case .initial, .starting: self.state = .shutdown - case .running(let channel): + case .running(let channel, let quiescingHelper): let shutdownPromise = channel.channel.eventLoop.makePromise(of: Void.self) - channel.channel.close(promise: shutdownPromise) self.state = .shuttingDown(shutdownPromise: shutdownPromise) + quiescingHelper.initiateShutdown(promise: shutdownPromise) + try await shutdownPromise.futureResult.get() + + // We need to check the state here again since we just awaited above + switch self.state { + case .initial, .starting, .running, .shutdown: + fatalError("Unexpected state \(self.state)") + + case .shuttingDown: + self.state = .shutdown + } case .shuttingDown(let shutdownPromise): // We are just going to queue up behind the current graceful shutdown @@ -179,8 +193,8 @@ public actor Server: Service { /// Start server /// - Parameter responder: Object that provides responses to requests sent to the server /// - Returns: EventLoopFuture that is fulfilled when server has started - nonisolated func makeServer(childChannelSetup: ChildChannel, configuration: ServerConfiguration) async throws -> AsyncServerChannel { - let bootstrap: ServerBootstrapProtocol + nonisolated func makeServer(childChannelSetup: ChildChannel, configuration: ServerConfiguration) async throws -> (AsyncServerChannel, ServerQuiescingHelper) { + var bootstrap: ServerBootstrapProtocol #if canImport(Network) if let tsBootstrap = self.createTSBootstrap(configuration: configuration) { bootstrap = tsBootstrap @@ -199,6 +213,11 @@ public actor Server: Service { ) #endif + let quiescingHelper = ServerQuiescingHelper(group: self.eventLoopGroup) + bootstrap = bootstrap.serverChannelInitializer { channel in + channel.pipeline.addHandler(quiescingHelper.makeServerChannelHandler(channel: channel)) + } + do { switch configuration.address.value { case .hostname(let host, let port): @@ -213,7 +232,7 @@ public actor Server: Service { ) } self.logger.info("Server started and listening on \(host):\(asyncChannel.channel.localAddress?.port ?? port)") - return asyncChannel + return (asyncChannel, quiescingHelper) case .unixDomainSocket(let path): let asyncChannel = try await bootstrap.bind( @@ -227,7 +246,7 @@ public actor Server: Service { ) } self.logger.info("Server started and listening on socket path \(path)") - return asyncChannel + return (asyncChannel, quiescingHelper) } } catch { // should we close the channel here @@ -271,6 +290,17 @@ public actor Server: Service { /// Protocol for bootstrap. protocol ServerBootstrapProtocol { + /// Initialize the `ServerSocketChannel` with `initializer`. The most common task in initializer is to add + /// `ChannelHandler`s to the `ChannelPipeline`. + /// + /// The `ServerSocketChannel` uses the accepted `Channel`s as inbound messages. + /// + /// - note: To set the initializer for the accepted `SocketChannel`s, look at `ServerBootstrap.childChannelInitializer`. + /// + /// - parameters: + /// - initializer: A closure that initializes the provided `Channel`. + func serverChannelInitializer(_ initializer: @escaping @Sendable (Channel) -> EventLoopFuture) -> Self + func bind( host: String, port: Int, diff --git a/Sources/HummingbirdTesting/LiveTestFramework.swift b/Sources/HummingbirdTesting/LiveTestFramework.swift index a0d80acef..8bc6b853d 100644 --- a/Sources/HummingbirdTesting/LiveTestFramework.swift +++ b/Sources/HummingbirdTesting/LiveTestFramework.swift @@ -72,12 +72,12 @@ final class LiveTestFramework: ApplicationTestFramewor client.connect() do { let value = try await test(Client(client: client)) - await serviceGroup.triggerGracefulShutdown() try await client.shutdown() + await serviceGroup.triggerGracefulShutdown() return value } catch { - await serviceGroup.triggerGracefulShutdown() try await client.shutdown() + await serviceGroup.triggerGracefulShutdown() throw error } } diff --git a/Tests/HummingbirdCoreTests/HTTP2Tests.swift b/Tests/HummingbirdCoreTests/HTTP2Tests.swift index 333317dfa..bd6df2c3d 100644 --- a/Tests/HummingbirdCoreTests/HTTP2Tests.swift +++ b/Tests/HummingbirdCoreTests/HTTP2Tests.swift @@ -28,6 +28,8 @@ class HummingBirdHTTP2Tests: XCTestCase { func testConnect() async throws { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + var logger = Logger(label: "Hummingbird") + logger.logLevel = .trace try await testServer( responder: { _, _ in .init(status: .ok) @@ -35,7 +37,7 @@ class HummingBirdHTTP2Tests: XCTestCase { httpChannelSetup: .http2Upgrade(tlsConfiguration: getServerTLSConfiguration()), configuration: .init(address: .hostname(port: 0), serverName: testServerName), eventLoopGroup: eventLoopGroup, - logger: Logger(label: "Hummingbird") + logger: logger ) { port in var tlsConfiguration = try getClientTLSConfiguration() // no way to override the SSL server name with AsyncHTTPClient so need to set From b725d1226ddb05144c46844c8452d9d99ac7678b Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Thu, 4 Jul 2024 11:32:19 +0100 Subject: [PATCH 2/5] Close on waiting for read to initiate, add better tests --- .../Server/HTTPUserEventHandler.swift | 9 +-- Tests/HummingbirdCoreTests/CoreTests.swift | 55 ++++++++++++++++++- 2 files changed, 53 insertions(+), 11 deletions(-) diff --git a/Sources/HummingbirdCore/Server/HTTPUserEventHandler.swift b/Sources/HummingbirdCore/Server/HTTPUserEventHandler.swift index d7e1fbbf6..d2076d41f 100644 --- a/Sources/HummingbirdCore/Server/HTTPUserEventHandler.swift +++ b/Sources/HummingbirdCore/Server/HTTPUserEventHandler.swift @@ -73,18 +73,11 @@ public final class HTTPUserEventHandler: ChannelDuplexHandler, RemovableChannelH case IdleStateHandler.IdleStateEvent.read: // if we get an idle read event and we haven't completed reading the request // close the connection - if self.requestsBeingRead > 0 { + if self.requestsBeingRead > 0 || self.requestsInProgress == 0 { self.logger.trace("Idle read timeout, so close channel") context.close(promise: nil) } - case IdleStateHandler.IdleStateEvent.write: - // if we get an idle write event and are not currently processing a request - if self.requestsInProgress == 0 { - self.logger.trace("Idle write timeout, so close channel") - context.close(mode: .input, promise: nil) - } - default: context.fireUserInboundEventTriggered(event) } diff --git a/Tests/HummingbirdCoreTests/CoreTests.swift b/Tests/HummingbirdCoreTests/CoreTests.swift index 5b515cb59..88305e665 100644 --- a/Tests/HummingbirdCoreTests/CoreTests.swift +++ b/Tests/HummingbirdCoreTests/CoreTests.swift @@ -262,7 +262,7 @@ class HummingBirdCoreTests: XCTestCase { } } - func testReadIdleHandler() async throws { + func testUnfinishedReadIdleHandler() async throws { /// Channel Handler for serializing request header and data final class HTTPServerIncompleteRequest: ChannelInboundHandler, RemovableChannelHandler { typealias InboundIn = HTTPRequestPart @@ -304,7 +304,56 @@ class HummingBirdCoreTests: XCTestCase { } } - func testWriteIdleTimeout() async throws { + func testUninitiatedReadIdleHandler() async throws { + /// Channel Handler for serializing request header and data + final class HTTPServerIncompleteRequest: ChannelInboundHandler, RemovableChannelHandler { + typealias InboundIn = HTTPRequestPart + typealias InboundOut = HTTPRequestPart + + func channelRead(context: ChannelHandlerContext, data: NIOAny) {} + } + try await testServer( + responder: { request, _ in + do { + _ = try await request.body.collect(upTo: .max) + } catch { + return Response(status: .contentTooLarge) + } + return .init(status: .ok) + }, + httpChannelSetup: .http1(additionalChannelHandlers: [HTTPServerIncompleteRequest(), IdleStateHandler(readTimeout: .seconds(1))]), + configuration: .init(address: .hostname(port: 0)), + eventLoopGroup: Self.eventLoopGroup, + logger: Logger(label: "Hummingbird") + ) { client in + try await withTimeout(.seconds(5)) { + do { + _ = try await client.get("/", headers: [.connection: "keep-alive"]) + XCTFail("Should not get here") + } catch TestClient.Error.connectionClosing { + } catch { + XCTFail("Unexpected error: \(error)") + } + } + } + } + + func testLeftOpenReadIdleHandler() async throws { + /// Channel Handler for serializing request header and data + final class HTTPServerIncompleteRequest: ChannelInboundHandler, RemovableChannelHandler { + typealias InboundIn = HTTPRequestPart + typealias InboundOut = HTTPRequestPart + var readOneRequest = false + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let part = self.unwrapInboundIn(data) + if !self.readOneRequest { + context.fireChannelRead(data) + } + if case .end = part { + self.readOneRequest = true + } + } + } try await testServer( responder: { request, _ in do { @@ -314,7 +363,7 @@ class HummingBirdCoreTests: XCTestCase { } return .init(status: .ok) }, - httpChannelSetup: .http1(additionalChannelHandlers: [IdleStateHandler(writeTimeout: .seconds(1))]), + httpChannelSetup: .http1(additionalChannelHandlers: [HTTPServerIncompleteRequest(), IdleStateHandler(readTimeout: .seconds(1))]), configuration: .init(address: .hostname(port: 0)), eventLoopGroup: Self.eventLoopGroup, logger: Logger(label: "Hummingbird") From 64033e671397cc507e5ec076141dcfbc55f1d79e Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Thu, 4 Jul 2024 11:33:26 +0100 Subject: [PATCH 3/5] comment --- Sources/HummingbirdCore/Server/HTTPUserEventHandler.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/HummingbirdCore/Server/HTTPUserEventHandler.swift b/Sources/HummingbirdCore/Server/HTTPUserEventHandler.swift index d2076d41f..3f00280d2 100644 --- a/Sources/HummingbirdCore/Server/HTTPUserEventHandler.swift +++ b/Sources/HummingbirdCore/Server/HTTPUserEventHandler.swift @@ -72,7 +72,7 @@ public final class HTTPUserEventHandler: ChannelDuplexHandler, RemovableChannelH case IdleStateHandler.IdleStateEvent.read: // if we get an idle read event and we haven't completed reading the request - // close the connection + // close the connection, or a request hasnt been initiated if self.requestsBeingRead > 0 || self.requestsInProgress == 0 { self.logger.trace("Idle read timeout, so close channel") context.close(promise: nil) From 0df07118b8c70c4e4358143cc564fb141363f7df Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Thu, 4 Jul 2024 14:29:56 +0100 Subject: [PATCH 4/5] Add test to ensure channel is shutdown on TLS dangling open --- Tests/HummingbirdCoreTests/TLSTests.swift | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/Tests/HummingbirdCoreTests/TLSTests.swift b/Tests/HummingbirdCoreTests/TLSTests.swift index 74bab8466..0b0a19cf5 100644 --- a/Tests/HummingbirdCoreTests/TLSTests.swift +++ b/Tests/HummingbirdCoreTests/TLSTests.swift @@ -16,6 +16,7 @@ import HummingbirdCore import HummingbirdTesting import HummingbirdTLS import Logging +import NIOConcurrencyHelpers import NIOCore import NIOPosix import NIOSSL @@ -39,4 +40,24 @@ class HummingBirdTLSTests: XCTestCase { XCTAssertEqual(body.readString(length: body.readableBytes), "Hello") } } + + func testGracefulShutdown() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let clientChannel: NIOLockedValueBox = .init(nil) + try await testServer( + responder: helloResponder, + httpChannelSetup: .tls(tlsConfiguration: getServerTLSConfiguration()), + configuration: .init(address: .hostname(port: 0), serverName: testServerName), + eventLoopGroup: eventLoopGroup, + logger: Logger(label: "Hummingbird") + ) { port in + let channel = try await ClientBootstrap(group: eventLoopGroup) + .connect(host: "127.0.0.1", port: port).get() + clientChannel.withLockedValue { $0 = channel } + } + // test channel has been closed + let channel = try clientChannel.withLockedValue { try XCTUnwrap($0) } + try await channel.closeFuture.get() + } } From e6daa2d649c9f18bc1831712de3eccaf0d0273ea Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Thu, 4 Jul 2024 14:32:44 +0100 Subject: [PATCH 5/5] Rename test --- Tests/HummingbirdCoreTests/TLSTests.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Tests/HummingbirdCoreTests/TLSTests.swift b/Tests/HummingbirdCoreTests/TLSTests.swift index 0b0a19cf5..5ebc1b7bc 100644 --- a/Tests/HummingbirdCoreTests/TLSTests.swift +++ b/Tests/HummingbirdCoreTests/TLSTests.swift @@ -41,7 +41,7 @@ class HummingBirdTLSTests: XCTestCase { } } - func testGracefulShutdown() async throws { + func testGracefulShutdownWithDanglingConnection() async throws { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let clientChannel: NIOLockedValueBox = .init(nil)