Skip to content

Commit

Permalink
http2: prevent enqueuing OutStreams multiple times (#3893)
Browse files Browse the repository at this point in the history
Fixes #3890
  • Loading branch information
jrudolph committed Aug 5, 2021
1 parent 422e635 commit 5f8fb8c
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 48 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Internals
ProblemFilters.exclude[MissingClassProblem]("akka.http.impl.engine.http2.PullFrameResult$NothingToSend$")
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ package object ccompat {
// in scala-library so we can't add to it
type IterableOnce[+X] = c.TraversableOnce[X]
val IterableOnce = c.TraversableOnce

implicit class RichQueue[T](val queue: mutable.Queue[T]) extends AnyVal {
// missing in 2.12
def -=(element: T): Unit = queue.dequeueAll(_ == element)
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

package akka.http.impl.engine.http2

import akka.http.ccompat._

import akka.annotation.InternalApi
import akka.event.LoggingAdapter
import akka.http.impl.engine.http2.FrameEvent._
Expand Down Expand Up @@ -45,7 +47,6 @@ private[http2] trait Http2Multiplexer {
private[http2] sealed abstract class PullFrameResult
@InternalApi
private[http2] object PullFrameResult {
final case object NothingToSend extends PullFrameResult
final case class SendFrame(frame: DataFrame, hasMore: Boolean) extends PullFrameResult
final case class SendFrameAndTrailer(frame: DataFrame, trailer: FrameEvent) extends PullFrameResult
}
Expand Down Expand Up @@ -134,6 +135,12 @@ private[http2] trait Http2MultiplexerSupport { logic: GraphStageLogic with Stage

private val controlFrameBuffer: mutable.Queue[FrameEvent] = new mutable.Queue[FrameEvent]
private val sendableOutstreams: mutable.Queue[Int] = new mutable.Queue[Int]
private def enqueueStream(streamId: Int): Unit = {
if (isDebugEnabled) require(!sendableOutstreams.contains(streamId), s"Stream [$streamId] was enqueued multiple times.") // requires expensive scanning -> avoid in production
sendableOutstreams.enqueue(streamId)
}
private def dequeueStream(streamId: Int): Unit =
sendableOutstreams -= streamId

private def updateState(transition: MultiplexerState => MultiplexerState): Unit = {
val oldState = _state
Expand Down Expand Up @@ -167,7 +174,7 @@ private[http2] trait Http2MultiplexerSupport { logic: GraphStageLogic with Stage
case PullFrameResult.SendFrame(frame, hasMore) =>
send(frame)
if (hasMore) {
sendableOutstreams += streamId
enqueueStream(streamId)
WaitingForNetworkToSendData
} else {
if (sendableOutstreams.isEmpty) Idle
Expand All @@ -177,13 +184,6 @@ private[http2] trait Http2MultiplexerSupport { logic: GraphStageLogic with Stage
send(frame)
controlFrameBuffer += trailer
WaitingForNetworkToSendControlFrames
case PullFrameResult.NothingToSend =>
// We are pulled but the stream that wanted to send, now chose otherwise.
// This can happen if either the stream got closed in the meantime, or if the stream was added to the queue
// multiple times, which can happen because `enqueueOutStream` is supposed to be idempotent but we don't check
// if we added an element several times to the queue (because it's inefficient).
if (sendableOutstreams.isEmpty) WaitingForData
else WaitingForNetworkToSendData.onPull()
}
}
}
Expand All @@ -203,7 +203,7 @@ private[http2] trait Http2MultiplexerSupport { logic: GraphStageLogic with Stage
}
def connectionWindowAvailable(): MultiplexerState = this
def enqueueOutStream(streamId: Int): MultiplexerState = {
sendableOutstreams += streamId
enqueueStream(streamId)
WaitingForNetworkToSendData
}
def closeStream(streamId: Int): MultiplexerState = this
Expand All @@ -218,7 +218,7 @@ private[http2] trait Http2MultiplexerSupport { logic: GraphStageLogic with Stage
def connectionWindowAvailable(): MultiplexerState = this // nothing to do, as there is no data to send
def enqueueOutStream(streamId: Int): MultiplexerState =
if (connectionWindowLeft == 0) {
sendableOutstreams += streamId
enqueueStream(streamId)
WaitingForConnectionWindow
} else sendDataFrame(streamId)
def closeStream(streamId: Int): MultiplexerState = this
Expand All @@ -239,12 +239,13 @@ private[http2] trait Http2MultiplexerSupport { logic: GraphStageLogic with Stage
}
def connectionWindowAvailable(): MultiplexerState = this
def enqueueOutStream(streamId: Int): MultiplexerState = {
sendableOutstreams += streamId
enqueueStream(streamId)
this
}

def closeStream(streamId: Int): MultiplexerState = {
// leave stream in sendableOutstreams, to be skipped in sendDataFrame
// expensive operation, but only called for cancelled streams
dequeueStream(streamId)
this
}
}
Expand All @@ -256,15 +257,17 @@ private[http2] trait Http2MultiplexerSupport { logic: GraphStageLogic with Stage
else {
val chosenId = prioritizer.chooseSubstream(sendableOutstreams.toSet)
// expensive operation, to be optimized when prioritizers can be configured
// in 2.12.x there's no Queue.-=, when 2.12.x support is dropped, this can be
// changed to Queue.-=
sendableOutstreams.dequeueAll(_ == chosenId)
dequeueStream(chosenId)
sendDataFrame(chosenId)
}

def closeStream(streamId: Int): MultiplexerState =
// leave stream in sendableOutstreams, to be skipped in sendDataFrame
this
def closeStream(streamId: Int): MultiplexerState = {
// expensive operation, but only called for cancelled streams
dequeueStream(streamId)
if (sendableOutstreams.nonEmpty) this
else if (pulled) WaitingForData
else Idle
}

def pulled: Boolean
}
Expand All @@ -281,7 +284,7 @@ private[http2] trait Http2MultiplexerSupport { logic: GraphStageLogic with Stage
}
def connectionWindowAvailable(): MultiplexerState = this
def enqueueOutStream(streamId: Int): MultiplexerState = {
sendableOutstreams += streamId
enqueueStream(streamId)
this
}

Expand All @@ -297,7 +300,7 @@ private[http2] trait Http2MultiplexerSupport { logic: GraphStageLogic with Stage
}
def connectionWindowAvailable(): MultiplexerState = sendNext()
def enqueueOutStream(streamId: Int): MultiplexerState = {
sendableOutstreams += streamId
enqueueStream(streamId)
this
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ private[http2] trait Http2StreamHandling { self: GraphStageLogic with LogHelper
def receivedUnexpectedFrame(e: StreamFrameEvent): StreamState = {
debug(s"Received unexpected frame of type ${e.frameTypeName} for stream ${e.streamId} in state $stateName")
pushGOAWAY(ErrorCode.PROTOCOL_ERROR, s"Received unexpected frame of type ${e.frameTypeName} for stream ${e.streamId} in state $stateName")
multiplexer.closeStream(e.streamId)
shutdown()
Closed
}

Expand Down Expand Up @@ -273,7 +273,7 @@ private[http2] trait Http2StreamHandling { self: GraphStageLogic with LogHelper
nextStateStream(buffer)
}

def pullNextFrame(maxSize: Int): (StreamState, PullFrameResult) = (this, PullFrameResult.NothingToSend)
def pullNextFrame(maxSize: Int): (StreamState, PullFrameResult) = throw new IllegalStateException(s"pullNextFrame not supported in state $stateName")
def incomingStreamPulled(): StreamState = throw new IllegalStateException(s"incomingStreamPulled not supported in state $stateName")

/** Called to cleanup any state when the connection is torn down */
Expand Down Expand Up @@ -330,24 +330,23 @@ private[http2] trait Http2StreamHandling { self: GraphStageLogic with LogHelper
trait Sending extends StreamState { _: Product =>
protected def outStream: OutStream

override def pullNextFrame(maxSize: Int): (StreamState, PullFrameResult) =
if (outStream.canSend) {
val frame = outStream.nextFrame(maxSize)
override def pullNextFrame(maxSize: Int): (StreamState, PullFrameResult) = {
val frame = outStream.nextFrame(maxSize)

val res =
outStream.endStreamIfPossible() match {
case Some(trailer) =>
PullFrameResult.SendFrameAndTrailer(frame, trailer)
case None =>
PullFrameResult.SendFrame(frame, outStream.canSend)
}
val res =
outStream.endStreamIfPossible() match {
case Some(trailer) =>
PullFrameResult.SendFrameAndTrailer(frame, trailer)
case None =>
PullFrameResult.SendFrame(frame, outStream.canSend)
}

val nextState =
if (outStream.isDone) handleOutgoingEnded()
else this
val nextState =
if (outStream.isDone) handleOutgoingEnded()
else this

(nextState, res)
} else (this, PullFrameResult.NothingToSend)
(nextState, res)
}

def handleWindowUpdate(windowUpdate: WindowUpdateFrame): StreamState = increaseWindow(windowUpdate.windowSizeIncrement)

Expand All @@ -371,7 +370,6 @@ private[http2] trait Http2StreamHandling { self: GraphStageLogic with LogHelper
case w: WindowUpdateFrame =>
handleWindowUpdate(w)
case r: RstStreamFrame =>
multiplexer.closeStream(r.streamId)
outStream.cancelStream()
Closed
case _ =>
Expand Down Expand Up @@ -463,7 +461,6 @@ private[http2] trait Http2StreamHandling { self: GraphStageLogic with LogHelper
override protected def onRstStreamFrame(rstStreamFrame: RstStreamFrame): Unit = {
super.onRstStreamFrame(rstStreamFrame)
outStream.cancelStream()
multiplexer.closeStream(rstStreamFrame.streamId)
}
override def incrementWindow(delta: Int): StreamState = {
outStream.increaseWindow(delta)
Expand All @@ -480,9 +477,7 @@ private[http2] trait Http2StreamHandling { self: GraphStageLogic with LogHelper
case class HalfClosedRemoteWaitingForOutgoingStream(extraInitialWindow: Int) extends StreamState {
// FIXME: DRY with below
override def handle(event: StreamFrameEvent): StreamState = event match {
case r: RstStreamFrame =>
multiplexer.closeStream(r.streamId)
Closed
case r: RstStreamFrame => Closed
case w: WindowUpdateFrame => copy(extraInitialWindow = extraInitialWindow + w.windowSizeIncrement)
case _ => receivedUnexpectedFrame(event)
}
Expand All @@ -500,7 +495,6 @@ private[http2] trait Http2StreamHandling { self: GraphStageLogic with LogHelper
def handle(event: StreamFrameEvent): StreamState = event match {
case r: RstStreamFrame =>
outStream.cancelStream()
multiplexer.closeStream(r.streamId)
Closed
case w: WindowUpdateFrame => handleWindowUpdate(w)
case _ => receivedUnexpectedFrame(event)
Expand Down Expand Up @@ -664,12 +658,19 @@ private[http2] trait Http2StreamHandling { self: GraphStageLogic with LogHelper

private var buffer: ByteString = ByteString.empty
private var upstreamClosed: Boolean = false
private var isEnqueued: Boolean = false
var endStreamSent: Boolean = false

/** Designates whether nextFrame can be called to get the next frame. */
def canSend: Boolean = buffer.nonEmpty && outboundWindowLeft > 0
def isDone: Boolean = endStreamSent

def enqueueIfPossible(): Unit =
if (canSend && !isEnqueued) {
isEnqueued = true
multiplexer.enqueueOutStream(streamId)
}

def registerIncomingData(inlet: SubSinkInlet[_]): Unit = {
require(!maybeInlet.isDefined)

Expand All @@ -681,7 +682,7 @@ private[http2] trait Http2StreamHandling { self: GraphStageLogic with LogHelper
require(buffer.isEmpty)
buffer = data
upstreamClosed = true
if (canSend) multiplexer.enqueueOutStream(streamId)
enqueueIfPossible()
}

def nextFrame(maxBytesToSend: Int): DataFrame = {
Expand All @@ -700,6 +701,10 @@ private[http2] trait Http2StreamHandling { self: GraphStageLogic with LogHelper

debug(s"[$streamId] sending ${toSend.length} bytes, endStream = $endStream, remaining buffer [${buffer.length}], remaining stream-level WINDOW [$outboundWindowLeft]")

// Multiplexer will enqueue for us if we report more data being available
// We cannot call `multiplexer.enqueueOutStream` from here because this is called from the multiplexer.
isEnqueued = !isDone && canSend

DataFrame(streamId, endStream, toSend)
}

Expand Down Expand Up @@ -730,13 +735,16 @@ private[http2] trait Http2StreamHandling { self: GraphStageLogic with LogHelper
}
}

def cancelStream(): Unit = cleanupStream()
def cancelStream(): Unit = {
cleanupStream()
if (isEnqueued) multiplexer.closeStream(streamId)
}
def bufferedBytes: Int = buffer.length

override def increaseWindow(increment: Int): Unit = if (increment >= 0) {
outboundWindowLeft += increment
debug(s"Updating window for $streamId by $increment to $outboundWindowLeft buffered bytes: $bufferedBytes")
if (canSend) multiplexer.enqueueOutStream(streamId)
enqueueIfPossible()
}

// external callbacks, need to make sure that potential stream state changing events are run through the state machine
Expand All @@ -753,7 +761,7 @@ private[http2] trait Http2StreamHandling { self: GraphStageLogic with LogHelper
maybePull()

// else wait for more data being drained
if (canSend) multiplexer.enqueueOutStream(streamId) // multiplexer might call pullNextFrame which goes through the state machine => OK
enqueueIfPossible() // multiplexer might call pullNextFrame which goes through the state machine => OK
}

override def onUpstreamFinish(): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,21 @@ class Http2ServerSpec extends AkkaSpecWithMaterializer("""
entityDataOut.sendComplete()
network.expectDATA(TheStreamId, endStream = true, ByteString.empty)
}
"keep sending entity data when data is chunked into small bits" inAssertAllStagesStopped new WaitingForResponseDataSetup {
val data1 = ByteString("a")
val numChunks = 10000 // on my machine it crashed at ~1700
(1 to numChunks).foreach(_ => entityDataOut.sendNext(data1))

val (false, data) = network.expectDATAFrame(TheStreamId)
data shouldEqual ByteString("a" * numChunks)

// now don't fail if there's demand on the line
network.plainDataProbe.request(1)
network.expectNoBytes(100.millis)

entityDataOut.sendComplete()
network.expectDATA(TheStreamId, endStream = true, ByteString.empty)
}

"parse priority frames" inAssertAllStagesStopped new WaitingForResponseDataSetup {
network.sendPRIORITY(TheStreamId, true, 0, 5)
Expand Down Expand Up @@ -1061,6 +1076,54 @@ class Http2ServerSpec extends AkkaSpecWithMaterializer("""
// also complete stream 1
sendDataAndExpectOnNet(entity1DataOut, 1, "", endStream = true)
}
"receiving RST_STREAM for one of two sendable streams" inAssertAllStagesStopped new TestSetup with RequestResponseProbes {
val theRequest = HttpRequest(protocol = HttpProtocols.`HTTP/2.0`)
network.sendRequest(1, theRequest)
user.expectRequest() shouldBe theRequest

val entity1DataOut = TestPublisher.probe[ByteString]()
val response1 = HttpResponse(entity = HttpEntity(ContentTypes.`application/octet-stream`, Source.fromPublisher(entity1DataOut)))
user.emitResponse(1, response1)
network.expectDecodedHEADERS(streamId = 1, endStream = false) shouldBe response1.withEntity(HttpEntity.Empty.withContentType(ContentTypes.`application/octet-stream`))

def sendDataAndExpectOnNet(outStream: TestPublisher.Probe[ByteString], streamId: Int, dataString: String, endStream: Boolean = false): Unit = {
val data = ByteString(dataString)
if (dataString.nonEmpty) outStream.sendNext(data)
if (endStream) outStream.sendComplete()
if (data.nonEmpty || endStream) network.expectDATA(streamId, endStream = endStream, data)
}

sendDataAndExpectOnNet(entity1DataOut, 1, "abc")

// send second request
network.sendRequest(3, theRequest)
user.expectRequest() shouldBe theRequest

val entity2DataOut = TestPublisher.probe[ByteString]()
val response2 = HttpResponse(entity = HttpEntity(ContentTypes.`application/octet-stream`, Source.fromPublisher(entity2DataOut)))
user.emitResponse(3, response2)
network.expectDecodedHEADERS(streamId = 3, endStream = false) shouldBe response2.withEntity(HttpEntity.Empty.withContentType(ContentTypes.`application/octet-stream`))

// send again on stream 1
sendDataAndExpectOnNet(entity1DataOut, 1, "zyx")

// now send on stream 2
sendDataAndExpectOnNet(entity2DataOut, 3, "mnopq")

// now again on stream 1
sendDataAndExpectOnNet(entity1DataOut, 1, "jklm")

// send two data bits first but only pull and expect later
entity1DataOut.sendNext(ByteString("hihihi"))
entity2DataOut.sendNext(ByteString("hohoho"))

network.sendRST_STREAM(1, ErrorCode.CANCEL)

network.expectDATA(3, endStream = false, ByteString("hohoho"))

// last data of stream 2
sendDataAndExpectOnNet(entity2DataOut, 3, "uvwx", endStream = true)
}
"close substreams when connection is shutting down" inAssertAllStagesStopped StreamTestKit.assertAllStagesStopped(new TestSetup with RequestResponseProbes {
val requestEntity = HttpEntity.Empty.withContentType(ContentTypes.`application/octet-stream`)
val request = HttpRequest(entity = requestEntity)
Expand Down

0 comments on commit 5f8fb8c

Please sign in to comment.