diff --git a/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift b/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift index 14e97715a..ba96b1988 100644 --- a/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift +++ b/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift @@ -32,80 +32,60 @@ enum HTTPChannelError: Error { case unexpectedHTTPPart(HTTPRequestPart) } -enum HTTPState: Int, Sendable { - case idle - case processing - case cancelled -} - extension HTTPChannelHandler { public func handleHTTP(asyncChannel: NIOAsyncChannel, logger: Logger) async { - let processingRequest = NIOLockedValueBox(HTTPState.idle) do { try await withTaskCancellationHandler { - try await withGracefulShutdownHandler { - try await asyncChannel.executeThenClose { inbound, outbound in - let responseWriter = HTTPServerBodyWriter(outbound: outbound) - var iterator = inbound.makeAsyncIterator() + try await asyncChannel.executeThenClose { inbound, outbound in + let responseWriter = HTTPServerBodyWriter(outbound: outbound) + var iterator = inbound.makeAsyncIterator() + + // read first part, verify it is a head + guard let part = try await iterator.next() else { return } + guard case .head(var head) = part else { + throw HTTPChannelError.unexpectedHTTPPart(part) + } - // read first part, verify it is a head - guard let part = try await iterator.next() else { return } - guard case .head(var head) = part else { - throw HTTPChannelError.unexpectedHTTPPart(part) + while true { + let bodyStream = NIOAsyncChannelRequestBody(iterator: iterator) + let request = Request(head: head, body: .init(asyncSequence: bodyStream)) + let response: Response + do { + response = try await self.responder(request, asyncChannel.channel) + } catch { + response = self.getErrorResponse(from: error, allocator: asyncChannel.channel.allocator) + } + do { + try await outbound.write(.head(response.head)) + let tailHeaders = try await response.body.write(responseWriter) + try await outbound.write(.end(tailHeaders)) + } catch { + throw error + } + if request.headers[.connection] == "close" { + return } + // Flush current request + // read until we don't have a body part + var part: HTTPRequestPart? while true { - // set to processing unless it is cancelled then exit - guard processingRequest.exchange(.processing) == .idle else { break } - - let bodyStream = NIOAsyncChannelRequestBody(iterator: iterator) - let request = Request(head: head, body: .init(asyncSequence: bodyStream)) - let response: Response - do { - response = try await self.responder(request, asyncChannel.channel) - } catch { - response = self.getErrorResponse(from: error, allocator: asyncChannel.channel.allocator) - } - do { - try await outbound.write(.head(response.head)) - let tailHeaders = try await response.body.write(responseWriter) - try await outbound.write(.end(tailHeaders)) - } catch { - throw error - } - if request.headers[.connection] == "close" { - return - } - // set to idle unless it is cancelled then exit - guard processingRequest.exchange(.idle) == .processing else { break } - - // Flush current request - // read until we don't have a body part - var part: HTTPRequestPart? - while true { - part = try await iterator.next() - guard case .body = part else { break } - } - // if we have an end then read the next part - if case .end = part { - part = try await iterator.next() - } - - // if part is nil break out of loop - guard let part else { - break - } + part = try await iterator.next() + guard case .body = part else { break } + } + // if we have an end then read the next part + if case .end = part { + part = try await iterator.next() + } - // part should be a head, if not throw error - guard case .head(let newHead) = part else { throw HTTPChannelError.unexpectedHTTPPart(part) } - head = newHead + // if part is nil break out of loop + guard let part else { + break } - } - } onGracefulShutdown: { - // set to cancelled - if processingRequest.exchange(.cancelled) == .idle { - // only close the channel input if it is idle - asyncChannel.channel.close(mode: .input, promise: nil) + + // part should be a head, if not throw error + guard case .head(let newHead) = part else { throw HTTPChannelError.unexpectedHTTPPart(part) } + head = newHead } } } onCancel: { diff --git a/Sources/HummingbirdCore/Server/HTTPUserEventHandler.swift b/Sources/HummingbirdCore/Server/HTTPUserEventHandler.swift index 2bf22b5a6..57309b8a8 100644 --- a/Sources/HummingbirdCore/Server/HTTPUserEventHandler.swift +++ b/Sources/HummingbirdCore/Server/HTTPUserEventHandler.swift @@ -35,7 +35,7 @@ public class HTTPUserEventHandler: ChannelDuplexHandler, RemovableChannelHandler let part = unwrapOutboundIn(data) if case .end = part { self.requestsInProgress -= 1 - context.write(data, promise: promise) + context.writeAndFlush(data, promise: promise) if self.closeAfterResponseWritten { context.close(promise: nil) self.closeAfterResponseWritten = false @@ -61,6 +61,15 @@ public class HTTPUserEventHandler: ChannelDuplexHandler, RemovableChannelHandler public func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { switch event { + case is ChannelShouldQuiesceEvent: + // we received a quiesce event. If we have any requests in progress we should + // wait for them to finish + if self.requestsInProgress > 0 { + self.closeAfterResponseWritten = true + } else { + context.close(promise: nil) + } + case IdleStateHandler.IdleStateEvent.read: // if we get an idle read event and we haven't completed reading the request // close the connection diff --git a/Sources/HummingbirdCore/Server/Server.swift b/Sources/HummingbirdCore/Server/Server.swift index 1f16242bb..0d64aabe6 100644 --- a/Sources/HummingbirdCore/Server/Server.swift +++ b/Sources/HummingbirdCore/Server/Server.swift @@ -33,7 +33,10 @@ public actor Server: Service { onServerRunning: (@Sendable (Channel) async -> Void)? ) case starting - case running(asyncChannel: AsyncServerChannel) + case running( + asyncChannel: AsyncServerChannel, + quiescingHelper: ServerQuiescingHelper + ) case shuttingDown(shutdownPromise: EventLoopPromise) case shutdown @@ -96,7 +99,7 @@ public actor Server: Service { self.state = .starting do { - let asyncChannel = try await self.makeServer( + let (asyncChannel, quiescingHelper) = try await self.makeServer( childChannelSetup: childChannelSetup, configuration: configuration ) @@ -107,7 +110,7 @@ public actor Server: Service { fatalError("We should only be running once") case .starting: - self.state = .running(asyncChannel: asyncChannel) + self.state = .running(asyncChannel: asyncChannel, quiescingHelper: quiescingHelper) await withGracefulShutdownHandler { await onServerRunning?(asyncChannel.channel) @@ -138,13 +141,14 @@ public actor Server: Service { } case .shuttingDown, .shutdown: + self.logger.info("Shutting down") try await asyncChannel.channel.close() } } catch { self.state = .shutdown throw error } - self.state = .shutdown + case .starting, .running: fatalError("Run should only be called once") @@ -162,10 +166,20 @@ public actor Server: Service { case .initial, .starting: self.state = .shutdown - case .running(let channel): + case .running(let channel, let quiescingHelper): let shutdownPromise = channel.channel.eventLoop.makePromise(of: Void.self) - channel.channel.close(promise: shutdownPromise) self.state = .shuttingDown(shutdownPromise: shutdownPromise) + quiescingHelper.initiateShutdown(promise: shutdownPromise) + try await shutdownPromise.futureResult.get() + + // We need to check the state here again since we just awaited above + switch self.state { + case .initial, .starting, .running, .shutdown: + fatalError("Unexpected state \(self.state)") + + case .shuttingDown: + self.state = .shutdown + } case .shuttingDown(let shutdownPromise): // We are just going to queue up behind the current graceful shutdown @@ -179,8 +193,8 @@ public actor Server: Service { /// Start server /// - Parameter responder: Object that provides responses to requests sent to the server /// - Returns: EventLoopFuture that is fulfilled when server has started - nonisolated func makeServer(childChannelSetup: ChildChannel, configuration: ServerConfiguration) async throws -> AsyncServerChannel { - let bootstrap: ServerBootstrapProtocol + nonisolated func makeServer(childChannelSetup: ChildChannel, configuration: ServerConfiguration) async throws -> (AsyncServerChannel, ServerQuiescingHelper) { + var bootstrap: ServerBootstrapProtocol #if canImport(Network) if let tsBootstrap = self.createTSBootstrap(configuration: configuration) { bootstrap = tsBootstrap @@ -199,6 +213,11 @@ public actor Server: Service { ) #endif + let quiescingHelper = ServerQuiescingHelper(group: self.eventLoopGroup) + bootstrap = bootstrap.serverChannelInitializer { channel in + channel.pipeline.addHandler(quiescingHelper.makeServerChannelHandler(channel: channel)) + } + do { switch configuration.address.value { case .hostname(let host, let port): @@ -213,7 +232,7 @@ public actor Server: Service { ) } self.logger.info("Server started and listening on \(host):\(asyncChannel.channel.localAddress?.port ?? port)") - return asyncChannel + return (asyncChannel, quiescingHelper) case .unixDomainSocket(let path): let asyncChannel = try await bootstrap.bind( @@ -227,7 +246,7 @@ public actor Server: Service { ) } self.logger.info("Server started and listening on socket path \(path)") - return asyncChannel + return (asyncChannel, quiescingHelper) } } catch { // should we close the channel here @@ -272,6 +291,17 @@ public actor Server: Service { /// Protocol for bootstrap. protocol ServerBootstrapProtocol { + /// Initialize the `ServerSocketChannel` with `initializer`. The most common task in initializer is to add + /// `ChannelHandler`s to the `ChannelPipeline`. + /// + /// The `ServerSocketChannel` uses the accepted `Channel`s as inbound messages. + /// + /// - note: To set the initializer for the accepted `SocketChannel`s, look at `ServerBootstrap.childChannelInitializer`. + /// + /// - parameters: + /// - initializer: A closure that initializes the provided `Channel`. + func serverChannelInitializer(_ initializer: @escaping @Sendable (Channel) -> EventLoopFuture) -> Self + func bind( host: String, port: Int, diff --git a/Sources/HummingbirdTesting/LiveTestFramework.swift b/Sources/HummingbirdTesting/LiveTestFramework.swift index a0d80acef..8bc6b853d 100644 --- a/Sources/HummingbirdTesting/LiveTestFramework.swift +++ b/Sources/HummingbirdTesting/LiveTestFramework.swift @@ -72,12 +72,12 @@ final class LiveTestFramework: ApplicationTestFramewor client.connect() do { let value = try await test(Client(client: client)) - await serviceGroup.triggerGracefulShutdown() try await client.shutdown() + await serviceGroup.triggerGracefulShutdown() return value } catch { - await serviceGroup.triggerGracefulShutdown() try await client.shutdown() + await serviceGroup.triggerGracefulShutdown() throw error } } diff --git a/Tests/HummingbirdCoreTests/HTTP2Tests.swift b/Tests/HummingbirdCoreTests/HTTP2Tests.swift index 81a8e2aca..3ab462de2 100644 --- a/Tests/HummingbirdCoreTests/HTTP2Tests.swift +++ b/Tests/HummingbirdCoreTests/HTTP2Tests.swift @@ -28,6 +28,8 @@ class HummingBirdHTTP2Tests: XCTestCase { func testConnect() async throws { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + var logger = Logger(label: "Hummingbird") + logger.logLevel = .trace try await testServer( responder: { _, _ in .init(status: .ok) @@ -35,7 +37,7 @@ class HummingBirdHTTP2Tests: XCTestCase { httpChannelSetup: .http2Upgrade(tlsConfiguration: getServerTLSConfiguration()), configuration: .init(address: .hostname(port: 0), serverName: testServerName), eventLoopGroup: eventLoopGroup, - logger: Logger(label: "Hummingbird") + logger: logger ) { port in var tlsConfiguration = try getClientTLSConfiguration() // no way to override the SSL server name with AsyncHTTPClient so need to set