Skip to content

Commit

Permalink
Finish multiplexer's inbound streams in more cases (#483)
Browse files Browse the repository at this point in the history
## Motivation
Currently, the multiplexer's inbound streams stream is finished only
when the channel becomes inactive.
There are some scenarios in which the channel may be closed before it
has a chance to become active, and the stream will never be finished.
This can cause any users iterating over the stream to hang.

## Modifications
This PR finishes the inbound streams stream when a connection error is
fired, and when the handler is removed.

## Result
Fewer bugs.
  • Loading branch information
gjcairo authored Nov 15, 2024
1 parent bb19976 commit 1879e72
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,24 @@ extension NIOHTTP2Handler.InboundStreamMultiplexer {
}
}

func errorCaughtReceived(_ error: any Error) {
switch self {
case .inline(let inlineStreamMultiplexer):
inlineStreamMultiplexer.propagateErrorCaught(error)
case .legacy:
break // do nothing
}
}

func handlerRemovedReceived() {
switch self {
case .inline(let inlineStreamMultiplexer):
inlineStreamMultiplexer.propagateHandlerRemoved()
case .legacy:
break // do nothing
}
}

func processedFrame(_ frame: HTTP2Frame) {
switch self {
case .inline(let inlineStreamMultiplexer):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,14 @@ extension InlineStreamMultiplexer {
}
}

internal func propagateErrorCaught(_ error: any Error) {
self._commonStreamMultiplexer.propagateErrorCaught(error)
}

internal func propagateHandlerRemoved() {
self._commonStreamMultiplexer.propagateHandlerRemoved()
}

internal func processedFrame(frame: HTTP2Frame) {
self._commonStreamMultiplexer.processedFrame(streamID: frame.streamID, size: frame.payload.flowControlledSize)
}
Expand Down
6 changes: 6 additions & 0 deletions Sources/NIOHTTP2/HTTP2ChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ public final class NIOHTTP2Handler: ChannelDuplexHandler {
public func handlerRemoved(context: ChannelHandlerContext) {
// Any frames we're buffering need to be dropped.
self.outboundBuffer.invalidateBuffer()
self.inboundStreamMultiplexer?.handlerRemovedReceived()
self.inboundStreamMultiplexerState = .deinitialized
}

Expand Down Expand Up @@ -550,6 +551,11 @@ public final class NIOHTTP2Handler: ChannelDuplexHandler {
self.inboundStreamMultiplexer?.channelWritabilityChangedReceived()
context.fireChannelWritabilityChanged()
}

public func errorCaught(context: ChannelHandlerContext, error: any Error) {
self.inboundStreamMultiplexer?.errorCaughtReceived(error)
context.fireErrorCaught(error)
}
}

/// Inbound frame handling.
Expand Down
8 changes: 8 additions & 0 deletions Sources/NIOHTTP2/HTTP2CommonInboundStreamMultiplexer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,14 @@ extension HTTP2CommonInboundStreamMultiplexer {
self.streamChannelContinuation?.finish()
}

internal func propagateErrorCaught(_ error: any Error) {
self.streamChannelContinuation?.finish(throwing: error)
}

internal func propagateHandlerRemoved() {
self.streamChannelContinuation?.finish()
}

internal func selectivelyPropagateUserInboundEvent(context: ChannelHandlerContext, event: Any) {
func propagateEvent(_ event: Any) {
for channel in self.streams.values {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -503,4 +503,67 @@ class SimpleClientServerInlineStreamMultiplexerTests: XCTestCase {
XCTAssertNoThrow(try self.clientChannel.finish())
XCTAssertNoThrow(try self.serverChannel.finish())
}

func testChannelInactiveFinishesAsyncStreamMultiplexerInboundStream() async throws {
let asyncClientChannel = NIOAsyncTestingChannel()
let asyncServerChannel = NIOAsyncTestingChannel()

// Setup the connection.
let clientMultiplexer = try await asyncClientChannel.configureAsyncHTTP2Pipeline(mode: .client) { _ in
asyncClientChannel.eventLoop.makeSucceededVoidFuture()
}.get()

let serverMultiplexer = try await asyncServerChannel.configureAsyncHTTP2Pipeline(mode: .server) { _ in
asyncServerChannel.eventLoop.makeSucceededVoidFuture()
}.get()

// Create the stream channel
let stream = try await clientMultiplexer.openStream { $0.eventLoop.makeSucceededFuture($0) }

// Initiate request to open the stream on the server.
let headers = HPACKHeaders([(":path", "/"), (":method", "POST"), (":scheme", "http")])
let frame: HTTP2Frame.FramePayload = .headers(.init(headers: headers))
stream.writeAndFlush(frame, promise: nil)
try await self.interactInMemory(asyncClientChannel, asyncServerChannel)

// Close server to fire channel inactive down the pipeline: it should be propagated.
try await asyncServerChannel.close()
for try await _ in serverMultiplexer.inbound {}
}

enum ErrorCaughtPropagated: Error, Equatable {
case error
}

func testErrorCaughtFinishesAsyncStreamMultiplexerInboundStream() async throws {
let asyncClientChannel = NIOAsyncTestingChannel()
let asyncServerChannel = NIOAsyncTestingChannel()

// Setup the connection.
let clientMultiplexer = try await asyncClientChannel.configureAsyncHTTP2Pipeline(mode: .client) { _ in
asyncClientChannel.eventLoop.makeSucceededVoidFuture()
}.get()

let serverMultiplexer = try await asyncServerChannel.configureAsyncHTTP2Pipeline(mode: .server) { _ in
asyncServerChannel.eventLoop.makeSucceededVoidFuture()
}.get()

// Create the stream channel
let stream = try await clientMultiplexer.openStream { $0.eventLoop.makeSucceededFuture($0) }

// Initiate request to open the stream on the server.
let headers = HPACKHeaders([(":path", "/"), (":method", "POST"), (":scheme", "http")])
let frame: HTTP2Frame.FramePayload = .headers(.init(headers: headers))
stream.writeAndFlush(frame, promise: nil)
try await self.interactInMemory(asyncClientChannel, asyncServerChannel)

// Fire an error down the server pipeline: it should cause the inbound stream to finish with error
asyncServerChannel.pipeline.fireErrorCaught(ErrorCaughtPropagated.error)
do {
for try await _ in serverMultiplexer.inbound {}
XCTFail("Expected error to be thrown")
} catch {
XCTAssert(error is ErrorCaughtPropagated)
}
}
}

0 comments on commit 1879e72

Please sign in to comment.