Skip to content

Commit

Permalink
Revert to using quiescing helper
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-fowler committed May 20, 2024
1 parent aed1e36 commit beeb780
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 78 deletions.
108 changes: 44 additions & 64 deletions Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,80 +32,60 @@ enum HTTPChannelError: Error {
case unexpectedHTTPPart(HTTPRequestPart)
}

enum HTTPState: Int, Sendable {
case idle
case processing
case cancelled
}

extension HTTPChannelHandler {
public func handleHTTP(asyncChannel: NIOAsyncChannel<HTTPRequestPart, HTTPResponsePart>, logger: Logger) async {
let processingRequest = NIOLockedValueBox(HTTPState.idle)
do {
try await withTaskCancellationHandler {
try await withGracefulShutdownHandler {
try await asyncChannel.executeThenClose { inbound, outbound in
let responseWriter = HTTPServerBodyWriter(outbound: outbound)
var iterator = inbound.makeAsyncIterator()
try await asyncChannel.executeThenClose { inbound, outbound in
let responseWriter = HTTPServerBodyWriter(outbound: outbound)
var iterator = inbound.makeAsyncIterator()

// read first part, verify it is a head
guard let part = try await iterator.next() else { return }
guard case .head(var head) = part else {
throw HTTPChannelError.unexpectedHTTPPart(part)
}

// read first part, verify it is a head
guard let part = try await iterator.next() else { return }
guard case .head(var head) = part else {
throw HTTPChannelError.unexpectedHTTPPart(part)
while true {
let bodyStream = NIOAsyncChannelRequestBody(iterator: iterator)
let request = Request(head: head, body: .init(asyncSequence: bodyStream))
let response: Response
do {
response = try await self.responder(request, asyncChannel.channel)
} catch {
response = self.getErrorResponse(from: error, allocator: asyncChannel.channel.allocator)
}
do {
try await outbound.write(.head(response.head))
let tailHeaders = try await response.body.write(responseWriter)
try await outbound.write(.end(tailHeaders))
} catch {
throw error
}
if request.headers[.connection] == "close" {
return
}

// Flush current request
// read until we don't have a body part
var part: HTTPRequestPart?
while true {
// set to processing unless it is cancelled then exit
guard processingRequest.exchange(.processing) == .idle else { break }

let bodyStream = NIOAsyncChannelRequestBody(iterator: iterator)
let request = Request(head: head, body: .init(asyncSequence: bodyStream))
let response: Response
do {
response = try await self.responder(request, asyncChannel.channel)
} catch {
response = self.getErrorResponse(from: error, allocator: asyncChannel.channel.allocator)
}
do {
try await outbound.write(.head(response.head))
let tailHeaders = try await response.body.write(responseWriter)
try await outbound.write(.end(tailHeaders))
} catch {
throw error
}
if request.headers[.connection] == "close" {
return
}
// set to idle unless it is cancelled then exit
guard processingRequest.exchange(.idle) == .processing else { break }

// Flush current request
// read until we don't have a body part
var part: HTTPRequestPart?
while true {
part = try await iterator.next()
guard case .body = part else { break }
}
// if we have an end then read the next part
if case .end = part {
part = try await iterator.next()
}

// if part is nil break out of loop
guard let part else {
break
}
part = try await iterator.next()
guard case .body = part else { break }
}
// if we have an end then read the next part
if case .end = part {
part = try await iterator.next()
}

// part should be a head, if not throw error
guard case .head(let newHead) = part else { throw HTTPChannelError.unexpectedHTTPPart(part) }
head = newHead
// if part is nil break out of loop
guard let part else {
break
}
}
} onGracefulShutdown: {
// set to cancelled
if processingRequest.exchange(.cancelled) == .idle {
// only close the channel input if it is idle
asyncChannel.channel.close(mode: .input, promise: nil)

// part should be a head, if not throw error
guard case .head(let newHead) = part else { throw HTTPChannelError.unexpectedHTTPPart(part) }
head = newHead
}
}
} onCancel: {
Expand Down
11 changes: 10 additions & 1 deletion Sources/HummingbirdCore/Server/HTTPUserEventHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public class HTTPUserEventHandler: ChannelDuplexHandler, RemovableChannelHandler
let part = unwrapOutboundIn(data)
if case .end = part {
self.requestsInProgress -= 1
context.write(data, promise: promise)
context.writeAndFlush(data, promise: promise)
if self.closeAfterResponseWritten {
context.close(promise: nil)
self.closeAfterResponseWritten = false
Expand All @@ -61,6 +61,15 @@ public class HTTPUserEventHandler: ChannelDuplexHandler, RemovableChannelHandler

public func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
switch event {
case is ChannelShouldQuiesceEvent:
// we received a quiesce event. If we have any requests in progress we should
// wait for them to finish
if self.requestsInProgress > 0 {
self.closeAfterResponseWritten = true
} else {
context.close(promise: nil)
}

case IdleStateHandler.IdleStateEvent.read:
// if we get an idle read event and we haven't completed reading the request
// close the connection
Expand Down
50 changes: 40 additions & 10 deletions Sources/HummingbirdCore/Server/Server.swift
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ public actor Server<ChildChannel: ServerChildChannel>: Service {
onServerRunning: (@Sendable (Channel) async -> Void)?
)
case starting
case running(asyncChannel: AsyncServerChannel)
case running(
asyncChannel: AsyncServerChannel,
quiescingHelper: ServerQuiescingHelper
)
case shuttingDown(shutdownPromise: EventLoopPromise<Void>)
case shutdown

Expand Down Expand Up @@ -96,7 +99,7 @@ public actor Server<ChildChannel: ServerChildChannel>: Service {
self.state = .starting

do {
let asyncChannel = try await self.makeServer(
let (asyncChannel, quiescingHelper) = try await self.makeServer(
childChannelSetup: childChannelSetup,
configuration: configuration
)
Expand All @@ -107,7 +110,7 @@ public actor Server<ChildChannel: ServerChildChannel>: Service {
fatalError("We should only be running once")

case .starting:
self.state = .running(asyncChannel: asyncChannel)
self.state = .running(asyncChannel: asyncChannel, quiescingHelper: quiescingHelper)

await withGracefulShutdownHandler {
await onServerRunning?(asyncChannel.channel)
Expand Down Expand Up @@ -138,13 +141,14 @@ public actor Server<ChildChannel: ServerChildChannel>: Service {
}

case .shuttingDown, .shutdown:
self.logger.info("Shutting down")
try await asyncChannel.channel.close()
}
} catch {
self.state = .shutdown
throw error
}
self.state = .shutdown

case .starting, .running:
fatalError("Run should only be called once")

Expand All @@ -162,10 +166,20 @@ public actor Server<ChildChannel: ServerChildChannel>: Service {
case .initial, .starting:
self.state = .shutdown

case .running(let channel):
case .running(let channel, let quiescingHelper):
let shutdownPromise = channel.channel.eventLoop.makePromise(of: Void.self)
channel.channel.close(promise: shutdownPromise)
self.state = .shuttingDown(shutdownPromise: shutdownPromise)
quiescingHelper.initiateShutdown(promise: shutdownPromise)
try await shutdownPromise.futureResult.get()

// We need to check the state here again since we just awaited above
switch self.state {
case .initial, .starting, .running, .shutdown:
fatalError("Unexpected state \(self.state)")

case .shuttingDown:
self.state = .shutdown
}

case .shuttingDown(let shutdownPromise):
// We are just going to queue up behind the current graceful shutdown
Expand All @@ -179,8 +193,8 @@ public actor Server<ChildChannel: ServerChildChannel>: Service {
/// Start server
/// - Parameter responder: Object that provides responses to requests sent to the server
/// - Returns: EventLoopFuture that is fulfilled when server has started
nonisolated func makeServer(childChannelSetup: ChildChannel, configuration: ServerConfiguration) async throws -> AsyncServerChannel {
let bootstrap: ServerBootstrapProtocol
nonisolated func makeServer(childChannelSetup: ChildChannel, configuration: ServerConfiguration) async throws -> (AsyncServerChannel, ServerQuiescingHelper) {
var bootstrap: ServerBootstrapProtocol
#if canImport(Network)
if let tsBootstrap = self.createTSBootstrap(configuration: configuration) {
bootstrap = tsBootstrap
Expand All @@ -199,6 +213,11 @@ public actor Server<ChildChannel: ServerChildChannel>: Service {
)
#endif

let quiescingHelper = ServerQuiescingHelper(group: self.eventLoopGroup)
bootstrap = bootstrap.serverChannelInitializer { channel in
channel.pipeline.addHandler(quiescingHelper.makeServerChannelHandler(channel: channel))
}

do {
switch configuration.address.value {
case .hostname(let host, let port):
Expand All @@ -213,7 +232,7 @@ public actor Server<ChildChannel: ServerChildChannel>: Service {
)
}
self.logger.info("Server started and listening on \(host):\(asyncChannel.channel.localAddress?.port ?? port)")
return asyncChannel
return (asyncChannel, quiescingHelper)

case .unixDomainSocket(let path):
let asyncChannel = try await bootstrap.bind(
Expand All @@ -227,7 +246,7 @@ public actor Server<ChildChannel: ServerChildChannel>: Service {
)
}
self.logger.info("Server started and listening on socket path \(path)")
return asyncChannel
return (asyncChannel, quiescingHelper)
}
} catch {
// should we close the channel here
Expand Down Expand Up @@ -272,6 +291,17 @@ public actor Server<ChildChannel: ServerChildChannel>: Service {

/// Protocol for bootstrap.
protocol ServerBootstrapProtocol {
/// Initialize the `ServerSocketChannel` with `initializer`. The most common task in initializer is to add
/// `ChannelHandler`s to the `ChannelPipeline`.
///
/// The `ServerSocketChannel` uses the accepted `Channel`s as inbound messages.
///
/// - note: To set the initializer for the accepted `SocketChannel`s, look at `ServerBootstrap.childChannelInitializer`.
///
/// - parameters:
/// - initializer: A closure that initializes the provided `Channel`.
func serverChannelInitializer(_ initializer: @escaping @Sendable (Channel) -> EventLoopFuture<Void>) -> Self

func bind<Output: Sendable>(
host: String,
port: Int,
Expand Down
4 changes: 2 additions & 2 deletions Sources/HummingbirdTesting/LiveTestFramework.swift
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,12 @@ final class LiveTestFramework<App: ApplicationProtocol>: ApplicationTestFramewor
client.connect()
do {
let value = try await test(Client(client: client))
await serviceGroup.triggerGracefulShutdown()
try await client.shutdown()
await serviceGroup.triggerGracefulShutdown()
return value
} catch {
await serviceGroup.triggerGracefulShutdown()
try await client.shutdown()
await serviceGroup.triggerGracefulShutdown()
throw error
}
}
Expand Down
4 changes: 3 additions & 1 deletion Tests/HummingbirdCoreTests/HTTP2Tests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@ class HummingBirdHTTP2Tests: XCTestCase {
func testConnect() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
var logger = Logger(label: "Hummingbird")
logger.logLevel = .trace
try await testServer(
responder: { _, _ in
.init(status: .ok)
},
httpChannelSetup: .http2Upgrade(tlsConfiguration: getServerTLSConfiguration()),
configuration: .init(address: .hostname(port: 0), serverName: testServerName),
eventLoopGroup: eventLoopGroup,
logger: Logger(label: "Hummingbird")
logger: logger
) { port in
var tlsConfiguration = try getClientTLSConfiguration()
// no way to override the SSL server name with AsyncHTTPClient so need to set
Expand Down

0 comments on commit beeb780

Please sign in to comment.