diff --git a/Sources/HummingbirdHTTP2/HTTP2Channel.swift b/Sources/HummingbirdHTTP2/HTTP2Channel.swift index ebb933e2..1b7077b5 100644 --- a/Sources/HummingbirdHTTP2/HTTP2Channel.swift +++ b/Sources/HummingbirdHTTP2/HTTP2Channel.swift @@ -40,6 +40,8 @@ public struct HTTP2UpgradeChannel: ServerChildChannel { public var idleTimeout: Duration? /// Maximum amount of time to wait before all streams are closed after second GOAWAY public var gracefulCloseTimeout: Duration? + /// Maximum amount of time a connection can be open + public var maxAgeTimeout: Duration? /// Configuration applieds to HTTP2 stream channels public var streamConfiguration: HTTP1Channel.Configuration @@ -51,6 +53,7 @@ public struct HTTP2UpgradeChannel: ServerChildChannel { public init( idleTimeout: Duration? = nil, gracefulCloseTimeout: Duration? = nil, + maxAgeTimeout: Duration? = nil, streamConfiguration: HTTP1Channel.Configuration = .init() ) { self.idleTimeout = idleTimeout @@ -126,6 +129,7 @@ public struct HTTP2UpgradeChannel: ServerChildChannel { let connectionManager = HTTP2ServerConnectionManager( eventLoop: channel.eventLoop, idleTimeout: self.configuration.idleTimeout, + maxAgeTimeout: self.configuration.maxAgeTimeout, gracefulCloseTimeout: self.configuration.gracefulCloseTimeout ) let handler: HTTP2ConnectionOutput = try channel.pipeline.syncOperations.configureAsyncHTTP2Pipeline( diff --git a/Sources/HummingbirdHTTP2/HTTP2ServerConnectionManager+StateMachine.swift b/Sources/HummingbirdHTTP2/HTTP2ServerConnectionManager+StateMachine.swift index 8b8e2bf7..57e9c79b 100644 --- a/Sources/HummingbirdHTTP2/HTTP2ServerConnectionManager+StateMachine.swift +++ b/Sources/HummingbirdHTTP2/HTTP2ServerConnectionManager+StateMachine.swift @@ -152,6 +152,29 @@ extension HTTP2ServerConnectionManager { return .none } } + + enum InputClosedResult { + case closeWithGoAway(lastStreamId: HTTP2StreamID) + case close + case none + } + + mutating func inputClosed() -> InputClosedResult { + switch self.state { + case .active(let activeState): + return .closeWithGoAway(lastStreamId: activeState.lastStreamId) + + case .closing(let closeState): + if closeState.sentSecondGoAway { + return .close + } else { + return .closeWithGoAway(lastStreamId: closeState.lastStreamId) + } + + case .closed: + return .none + } + } } } diff --git a/Sources/HummingbirdHTTP2/HTTP2ServerConnectionManager.swift b/Sources/HummingbirdHTTP2/HTTP2ServerConnectionManager.swift index e2596900..83329d14 100644 --- a/Sources/HummingbirdHTTP2/HTTP2ServerConnectionManager.swift +++ b/Sources/HummingbirdHTTP2/HTTP2ServerConnectionManager.swift @@ -28,6 +28,8 @@ final class HTTP2ServerConnectionManager: ChannelDuplexHandler { var state: StateMachine /// Idle timer var idleTimer: Timer? + /// Maximum time a connection be open timer + var maxAgeTimer: Timer? /// Maximum amount of time we wait before closing the connection var gracefulCloseTimer: Timer? /// EventLoop connection manager running on @@ -39,12 +41,18 @@ final class HTTP2ServerConnectionManager: ChannelDuplexHandler { /// flush pending when read completes var flushPending: Bool - init(eventLoop: EventLoop, idleTimeout: Duration?, gracefulCloseTimeout: Duration?) { + init( + eventLoop: EventLoop, + idleTimeout: Duration?, + maxAgeTimeout: Duration?, + gracefulCloseTimeout: Duration? + ) { self.eventLoop = eventLoop self.state = .init() self.inReadLoop = false self.flushPending = false self.idleTimer = idleTimeout.map { Timer(delay: .init($0)) } + self.maxAgeTimer = maxAgeTimeout.map { Timer(delay: .init($0)) } self.gracefulCloseTimer = gracefulCloseTimeout.map { Timer(delay: .init($0)) } } @@ -54,6 +62,9 @@ final class HTTP2ServerConnectionManager: ChannelDuplexHandler { self.idleTimer?.schedule(on: self.eventLoop) { loopBoundHandler.triggerGracefulShutdown() } + self.maxAgeTimer?.schedule(on: self.eventLoop) { + loopBoundHandler.triggerGracefulShutdown() + } } func handlerRemoved(context: ChannelHandlerContext) { @@ -102,6 +113,8 @@ final class HTTP2ServerConnectionManager: ChannelDuplexHandler { switch event { case is ChannelShouldQuiesceEvent: self.triggerGracefulShutdown(context: context) + case let channelEvent as ChannelEvent where channelEvent == .inputClosed: + self.handleInputClosed(context: context) default: break } @@ -160,6 +173,12 @@ final class HTTP2ServerConnectionManager: ChannelDuplexHandler { if close { context.close(promise: nil) + } else { + // Setup grace period for closing. Close the connection abruptly once the grace period passes. + let loopBound = NIOLoopBound(context, eventLoop: context.eventLoop) + self.gracefulCloseTimer?.schedule(on: context.eventLoop) { + loopBound.value.close(promise: nil) + } } case .none: break @@ -182,11 +201,29 @@ final class HTTP2ServerConnectionManager: ChannelDuplexHandler { context.write(self.wrapOutboundOut(ping), promise: nil) self.optionallyFlush(context: context) - // Setup grace period for closing. Close the connection abruptly once the grace period passes. - let loopBound = NIOLoopBound(context, eventLoop: context.eventLoop) - self.gracefulCloseTimer?.schedule(on: context.eventLoop) { - loopBound.value.close(promise: nil) - } + case .none: + break + } + } + + func handleInputClosed(context: ChannelHandlerContext) { + switch self.state.inputClosed() { + case .closeWithGoAway(let lastStreamId): + let goAway = HTTP2Frame( + streamID: .rootStream, + payload: .goAway( + lastStreamID: lastStreamId, + errorCode: .connectError, + opaqueData: context.channel.allocator.buffer(string: "input_closed") + ) + ) + + context.write(self.wrapOutboundOut(goAway), promise: nil) + self.optionallyFlush(context: context) + context.close(promise: nil) + + case .close: + context.close(promise: nil) case .none: break diff --git a/Tests/HummingbirdHTTP2Tests/HTTP2Tests.swift b/Tests/HummingbirdHTTP2Tests/HTTP2Tests.swift index c120cdc3..91bf96c7 100644 --- a/Tests/HummingbirdHTTP2Tests/HTTP2Tests.swift +++ b/Tests/HummingbirdHTTP2Tests/HTTP2Tests.swift @@ -55,7 +55,7 @@ final class HummingBirdHTTP2Tests: XCTestCase { } } - func testTwoRequests() async throws { + func testMultipleSerialRequests() async throws { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } var logger = Logger(label: "Hummingbird") @@ -76,16 +76,51 @@ final class HummingBirdHTTP2Tests: XCTestCase { logger: logger ) { port in let request = HTTPClientRequest(url: "https://localhost:\(port)/") - let response = try await httpClient.execute(request, deadline: .now() + .seconds(30)) - let response2 = try await httpClient.execute(request, deadline: .now() + .seconds(30)) - _ = try await response.body.collect(upTo: .max) - _ = try await response2.body.collect(upTo: .max) - XCTAssertEqual(response.status, .ok) + for _ in 0..<16 { + let response = try await httpClient.execute(request, deadline: .now() + .seconds(30)) + _ = try await response.body.collect(upTo: .max) + XCTAssertEqual(response.status, .ok) + } + } + } + } + + func testMultipleConcurrentRequests() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + var logger = Logger(label: "Hummingbird") + logger.logLevel = .trace + + var tlsConfiguration = try getClientTLSConfiguration() + // no way to override the SSL server name with AsyncHTTPClient so need to set + // hostname verification off + tlsConfiguration.certificateVerification = .noHostnameVerification + try await withHTTPClient(.init(tlsConfiguration: tlsConfiguration)) { httpClient in + try await testServer( + responder: { (_, responseWriter: consuming ResponseWriter, _) in + try await responseWriter.writeResponse(.init(status: .ok)) + }, + httpChannelSetup: .http2Upgrade(tlsConfiguration: getServerTLSConfiguration()), + configuration: .init(address: .hostname(port: 0), serverName: testServerName), + eventLoopGroup: eventLoopGroup, + logger: logger + ) { port in + let request = HTTPClientRequest(url: "https://localhost:\(port)/") + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + for _ in 0..<16 { + let response = try await httpClient.execute(request, deadline: .now() + .seconds(30)) + _ = try await response.body.collect(upTo: .max) + XCTAssertEqual(response.status, .ok) + } + } + try await group.waitForAll() + } } } } - func testGracefulTimeout() async throws { + func testConnectionClosed() async throws { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } var logger = Logger(label: "Hummingbird") @@ -96,11 +131,7 @@ final class HummingBirdHTTP2Tests: XCTestCase { try await responseWriter.writeResponse(.init(status: .ok)) }, httpChannelSetup: .http2Upgrade( - tlsConfiguration: getServerTLSConfiguration(), - configuration: .init( - gracefulCloseTimeout: .seconds(1), - streamConfiguration: .init(idleTimeout: .seconds(1)) - ) + tlsConfiguration: getServerTLSConfiguration() ), configuration: .init(address: .hostname(port: 0), serverName: testServerName), eventLoopGroup: eventLoopGroup,