Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cancel on inbound close #627

Closed
wants to merge 10 commits into from
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
.build/
.swiftpm/
.vscode/
.index-build/
.devcontainer/
/Packages
/*.xcodeproj
Expand Down
95 changes: 95 additions & 0 deletions Sources/HummingbirdCore/Request/Request.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
//===----------------------------------------------------------------------===//

import HTTPTypes
import NIOConcurrencyHelpers
import NIOCore
import NIOHTTPTypes

/// Holds all the values required to process a request
public struct Request: Sendable {
Expand All @@ -32,6 +34,8 @@ public struct Request: Sendable {
@inlinable
public var headers: HTTPFields { self.head.headerFields }

private let iterationState: RequestIterationState?

// MARK: Initialization

/// Create new Request
Expand All @@ -45,6 +49,7 @@ public struct Request: Sendable {
self.uri = .init(head.path ?? "")
self.head = head
self.body = body
self.iterationState = .init()
}

/// Collapse body into one ByteBuffer.
Expand All @@ -60,10 +65,100 @@ public struct Request: Sendable {
self.body = .init(buffer: byteBuffer)
return byteBuffer
}

public func cancelOnInboundClose<Value: Sendable>(_ process: @escaping @Sendable (Request) async throws -> Value) async throws -> Value {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be part of the (Core)RequestContext instead

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Server needs access to this object and it knows nothing about RequestContexts so not sure how we can do this.

guard let iterationState = self.iterationState else { return try await process(self) }
let iterator: UnsafeTransfer<NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator>? =
switch self.body._backing {
case .nioAsyncChannelRequestBody(let iterator):
iterator.underlyingIterator
default:
nil
}
let (stream, source) = RequestBody.makeStream()
var request = self
request.body = stream
let newRequest = request
return try await iterationState.cancelOnIteratorFinished(iterator: iterator, source: source) {
try await process(newRequest)
}
}

package func getState() -> RequestIterationState.State? {
self.iterationState?.state.withLockedValue { $0 }
}
}

extension Request: CustomStringConvertible {
public var description: String {
"uri: \(self.uri), method: \(self.method), headers: \(self.headers), body: \(self.body)"
}
}

package struct RequestIterationState: Sendable {
fileprivate enum CancelOnInboundGroupType<Value: Sendable> {
case value(Value)
case done
}
package enum State: Sendable {
case idle
case processing
case nextHead(HTTPRequest)
case closed
}
let state: NIOLockedValueBox<State>

init() {
self.state = .init(.idle)
}

func cancelOnIteratorFinished<Value: Sendable>(
iterator: UnsafeTransfer<NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator>?,
source: RequestBody.Source,
process: @escaping @Sendable () async throws -> Value
) async throws -> Value {
let state = self.state.withLockedValue { $0 }
let unsafeSource = UnsafeTransfer(source)
switch (state, iterator) {
case (.idle, .some(let asyncIterator)):
self.state.withLockedValue { $0 = .processing }
return try await withThrowingTaskGroup(of: CancelOnInboundGroupType<Value>.self) { group in
group.addTask {
var asyncIterator = asyncIterator.wrappedValue
let source = unsafeSource.wrappedValue
while let part = try await asyncIterator.next() {
switch part {
case .head(let head):
self.state.withLockedValue { $0 = .nextHead(head) }
return .done
case .body(let buffer):
try await source.yield(buffer)
case .end:
source.finish()
}
}
throw CancellationError()
}
group.addTask {
try await .value(process())
}
do {
while let result = try await group.next() {
if case .value(let value) = result {
return value
}
}
} catch {
self.state.withLockedValue { $0 = .closed }
throw error
}
preconditionFailure("Cannot reach here")
}
case (.idle, .none), (.processing, _), (.nextHead, _):
return try await process()

case (.closed, _):
throw CancellationError()
}
}
}
2 changes: 1 addition & 1 deletion Sources/HummingbirdCore/Response/ResponseBodyWriter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import HTTPTypes
import NIOCore

/// HTTP Response Body part writer
public protocol ResponseBodyWriter {
public protocol ResponseBodyWriter: Sendable {
/// Write a single ByteBuffer
/// - Parameter buffer: single buffer to write
mutating func write(_ buffer: ByteBuffer) async throws
Expand Down
2 changes: 1 addition & 1 deletion Sources/HummingbirdCore/Response/ResponseWriter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import NIOCore
import NIOHTTPTypes

/// ResponseWriter that writes directly to AsyncChannel
public struct ResponseWriter: ~Copyable {
public struct ResponseWriter: ~Copyable, Sendable {
@usableFromInline
let outbound: NIOAsyncChannelOutboundWriter<HTTPResponsePart>

Expand Down
12 changes: 10 additions & 2 deletions Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ extension HTTPChannelHandler {
throw HTTPChannelError.unexpectedHTTPPart(part)
}

while true {
readParts: while true {
let bodyStream = NIOAsyncChannelRequestBody(iterator: iterator)
let request = Request(head: head, body: .init(nioAsyncChannelInbound: bodyStream))
let responseWriter = ResponseWriter(outbound: outbound)
Expand All @@ -57,7 +57,15 @@ extension HTTPChannelHandler {
if request.headers[.connection] == "close" {
return
}

switch request.getState() {
case .nextHead(let newHead):
head = newHead
continue
case .closed:
break readParts
default:
break
}
// Flush current request
// read until we don't have a body part
var part: HTTPRequestPart?
Expand Down
14 changes: 11 additions & 3 deletions Sources/HummingbirdTesting/TestClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ public struct TestClient: Sendable {
/// shutdown client
public func shutdown() async throws {
do {
try await self.close()
try await self.close(mode: .all)
} catch TestClient.Error.connectionNotOpen {
} catch ChannelError.alreadyClosed {}
if case .createNew = self.eventLoopGroupProvider {
Expand Down Expand Up @@ -163,10 +163,18 @@ public struct TestClient: Sendable {
return response
}

public func close() async throws {
/// Execute request to server but don't wait for response.
public func executeAndDontWaitForResponse(_ request: TestClient.Request) async throws {
let channel = try await getChannel()
let promise = self.eventLoopGroup.any().makePromise(of: TestClient.Response.self)
let task = HTTPTask(request: self.cleanupRequest(request), responsePromise: promise)
try await channel.writeAndFlush(task)
}

public func close(mode: CloseMode = .output) async throws {
self.channelPromise.completeWith(.failure(TestClient.Error.connectionNotOpen))
let channel = try await getChannel()
return try await channel.close()
return try await channel.close(mode: mode)
}

public func getChannel() async throws -> Channel {
Expand Down
51 changes: 51 additions & 0 deletions Tests/HummingbirdCoreTests/CoreTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,57 @@ final class HummingBirdCoreTests: XCTestCase {
await serviceGroup.triggerGracefulShutdown()
}
}

/// Test running cancel on inbound close without an inbound close
func testCancelOnCloseInboundWithoutClose() async throws {
try await testServer(
responder: { (request, responseWriter: consuming ResponseWriter, _) in
let bodyWriter = try await responseWriter.writeHead(.init(status: .ok))
try await request.cancelOnInboundClose { request in
var bodyWriter = bodyWriter
try await bodyWriter.write(request.body)
try await bodyWriter.finish(nil)
}
},
httpChannelSetup: .http1(),
configuration: .init(address: .hostname(port: 0)),
eventLoopGroup: Self.eventLoopGroup,
logger: Logger(label: "Hummingbird")
) { client in
let response = try await client.get("/")
XCTAssertNil(response.body)
let response2 = try await client.post("/", body: ByteBuffer(string: "Hello"))
let body2 = try XCTUnwrap(response2.body)
XCTAssertEqual(String(buffer: body2), "Hello")
}
}

/// Test running cancel on inbound close actually cancels on inbound closure
func testCancelOnCloseInbound() async throws {
let handlerPromise = Promise<Void>()
try await testServer(
responder: { (request, responseWriter: consuming ResponseWriter, _) in
await handlerPromise.complete(())
let bodyWriter = try await responseWriter.writeHead(.init(status: .ok))
try await request.cancelOnInboundClose { request in
var bodyWriter2 = bodyWriter
let body = try await request.body.collect(upTo: .max)
while true {
try Task.checkCancellation()
try await bodyWriter2.write(body)
}
}
},
httpChannelSetup: .http1(),
configuration: .init(address: .hostname(port: 0)),
eventLoopGroup: Self.eventLoopGroup,
logger: Logger(label: "Hummingbird")
) { client in
try await client.executeAndDontWaitForResponse(.init("/", method: .get))
await handlerPromise.wait()
try await client.close()
}
}
}

struct DelayAsyncSequence<CoreSequence: AsyncSequence>: AsyncSequence {
Expand Down
38 changes: 35 additions & 3 deletions Tests/HummingbirdCoreTests/TestUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,39 @@ public func testServer<Value: Sendable>(
clientConfiguration: TestClient.Configuration = .init(),
test: @escaping @Sendable (TestClient) async throws -> Value
) async throws -> Value {
try await testServer(
try await withThrowingTaskGroup(of: Void.self) { group in
let promise = Promise<Int>()
let server = try httpChannelSetup.buildServer(
configuration: configuration,
eventLoopGroup: eventLoopGroup,
logger: logger,
responder: responder,
onServerRunning: { await promise.complete($0.localAddress!.port!) }
)
let serviceGroup = ServiceGroup(
configuration: .init(
services: [server],
gracefulShutdownSignals: [.sigterm, .sigint],
logger: logger
)
)

group.addTask {
try await serviceGroup.run()
}
let client = await TestClient(
host: "localhost",
port: promise.wait(),
configuration: clientConfiguration,
eventLoopGroupProvider: .createNew
)
client.connect()
let value = try await test(client)
try? await client.shutdown()
await serviceGroup.triggerGracefulShutdown()
return value
}
/* try await testServer(
responder: responder,
httpChannelSetup: httpChannelSetup,
configuration: configuration,
Expand All @@ -95,9 +127,9 @@ public func testServer<Value: Sendable>(
)
client.connect()
let value = try await test(client)
try await client.shutdown()
try? await client.shutdown()
return value
}
}*/
}

/// Run process with a timeout
Expand Down
Loading